
Exploratory Data Analysis – Visualization using Matplotlib and Seaborn – with Python code


This tutorial covers some basic usage patterns and best practices to help you get started with Matplotlib.

import matplotlib as mpl

import matplotlib.pyplot as plt

import numpy as np

A simple example

Matplotlib graphs your data on Figures (e.g., windows, Jupyter widgets, etc.), each of which can contain one or more Axes, an area where points can be specified in terms of x-y coordinates (or theta-r in a polar plot, x-y-z in a 3D plot, etc). The simplest way of creating a Figure with an Axes is using pyplot.subplots. We can then use Axes.plot to draw some data on the Axes:

fig, ax = plt.subplots()  # Create a figure containing a single axes.

ax.plot([1, 2, 3, 4], [1, 4, 2, 3]);  # Plot some data on the axes.

Parts of a Matplotlib Figure

Here are the components of a Matplotlib Figure.



Figure

The whole figure. The Figure keeps track of all the child Axes, a group of ‘special’ Artists (titles, figure legends, colorbars, etc), and even nested subfigures.

The easiest way to create a new Figure is with pyplot:

fig = plt.figure()  # an empty figure with no Axes

fig, ax = plt.subplots()  # a figure with a single Axes

fig, axs = plt.subplots(2, 2)  # a figure with a 2x2 grid of Axes

It is often convenient to create the Axes together with the Figure, but you can also manually add Axes later on. Note that many Matplotlib backends support zooming and panning on figure windows.


Axes

An Axes is an Artist attached to a Figure that contains a region for plotting data, and usually includes two (or three in the case of 3D) Axis objects (be aware of the difference between Axes and Axis) that provide ticks and tick labels to provide scales for the data in the Axes. Each Axes also has a title (set via set_title()), an x-label (set via set_xlabel()), and a y-label (set via set_ylabel()).

The Axes class and its member functions are the primary entry point to working with the OOP interface, and have most of the plotting methods defined on them (e.g. ax.plot(), shown above, uses the plot method).


Axis

These objects set the scale and limits and generate ticks (the marks on the Axis) and ticklabels (strings labeling the ticks). The location of the ticks is determined by a Locator object and the ticklabel strings are formatted by a Formatter. The combination of the correct Locator and Formatter gives very fine control over the tick locations and labels.


Artist

Basically, everything visible on the Figure is an Artist (even Figure, Axes, and Axis objects). This includes Text objects, Line2D objects, collections objects, Patch objects, etc. When the Figure is rendered, all of the Artists are drawn to the canvas. Most Artists are tied to an Axes; such an Artist cannot be shared by multiple Axes, or moved from one to another.
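The hierarchy described above can be checked directly on live objects. The following is a minimal sketch, assuming a headless "Agg" backend (chosen here only so the snippet runs without a display); it confirms that the Figure, Axes, Axis, and plotted line are all Artists, and that the line belongs to exactly one Axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
from matplotlib.artist import Artist

fig, ax = plt.subplots()
(line,) = ax.plot([1, 2, 3, 4], [1, 4, 2, 3])

# Figure, Axes, both Axis objects and the Line2D are all Artists.
all_artists = all(
    isinstance(obj, Artist) for obj in (fig, ax, ax.xaxis, ax.yaxis, line)
)
print(all_artists)

# The line is tied to one Axes, and that Axes is tied to one Figure.
print(line.axes is ax, ax.figure is fig)

# The Figure keeps track of its child Axes.
print(fig.axes == [ax])
```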

Types of inputs to Matplotlib plotting functions

Plotting functions expect numpy.array or numpy.ma.masked_array as input, or objects that can be passed to numpy.asarray. Classes that are similar to arrays (‘array-like’) such as pandas data objects and numpy.matrix may not work as intended. Common convention is to convert these to numpy.array objects prior to plotting. For example, to convert a numpy.matrix:

b = np.matrix([[1, 2], [3, 4]])

b_asarray = np.asarray(b)

Most methods will also parse an addressable object like a dict, a numpy.recarray, or a pandas.DataFrame. Matplotlib allows you to provide the data keyword argument and generate plots by passing the strings corresponding to the x and y variables.

np.random.seed(19680801)  # seed the random number generator.

data = {'a': np.arange(50),
        'c': np.random.randint(0, 50, 50),
        'd': np.random.randn(50)}

data['b'] = data['a'] + 10 * np.random.randn(50)

data['d'] = np.abs(data['d']) * 100

fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')

ax.scatter('a', 'b', c='c', s='d', data=data)

ax.set_xlabel('entry a')

ax.set_ylabel('entry b')

Coding styles – The object-oriented and the pyplot interfaces

As noted above, there are essentially two ways to use Matplotlib:

  • Explicitly create Figures and Axes, and call methods on them (the “object-oriented (OO) style”).
  • Rely on pyplot to automatically create and manage the Figures and Axes, and use pyplot functions for plotting.

Object Oriented Interface

So one can use the OO-style:

x = np.linspace(0, 2, 100)  # Sample data.

# Note that even in the OO-style, we use pyplot (here plt.subplots) to create the Figure.

fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')

ax.plot(x, x, label='linear')  # Plot some data on the axes.

ax.plot(x, x**2, label='quadratic')  # Plot more data on the axes...

ax.plot(x, x**3, label='cubic')  # ... and some more.

ax.set_xlabel('x label')  # Add an x-label to the axes.

ax.set_ylabel('y label')  # Add a y-label to the axes.

ax.set_title("Simple Plot")  # Add a title to the axes.

ax.legend();  # Add a legend.

Pyplot Interface

x = np.linspace(0, 2, 100)  # Sample data.

plt.figure(figsize=(5, 2.7), layout='constrained')

plt.plot(x, x, label='linear')  # Plot some data on the (implicit) axes.

plt.plot(x, x**2, label='quadratic')  # etc.

plt.plot(x, x**3, label='cubic')

plt.xlabel('x label')

plt.ylabel('y label')

plt.title("Simple Plot")

plt.legend()  # Add a legend so the labels above are shown.

Matplotlib’s documentation and examples use both the OO and the pyplot styles. In general, we suggest using the OO style, particularly for complicated plots, and functions and scripts that are intended to be reused as part of a larger project. However, the pyplot style can be very convenient for quick interactive work.

Helper functions in Matplotlib

If you need to make the same plots over and over again with different data sets, or want to easily wrap Matplotlib methods, write a helper function with the recommended signature shown below.

def my_plotter(ax, data1, data2, param_dict):
    """
    A helper function to make a graph.
    """
    out = ax.plot(data1, data2, **param_dict)
    return out

data1, data2, data3, data4 = np.random.randn(4, 100)  # make 4 random data sets

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5, 2.7))

# use the helper function twice to populate two subplots:

my_plotter(ax1, data1, data2, {'marker': 'x'})

my_plotter(ax2, data3, data4, {'marker': 'o'});

Styling Artists

Most plotting methods have styling options for the Artists, accessible either when a plotting method is called, or from a “setter” on the Artist. In the plot below we manually set the color, linewidth, and linestyle of the Artists created by plot, and we set the linestyle of the second line after the fact with set_linestyle.

fig, ax = plt.subplots(figsize=(5, 2.7))

x = np.arange(len(data1))

ax.plot(x, np.cumsum(data1), color='blue', linewidth=3, linestyle='--')

l, = ax.plot(x, np.cumsum(data2), color='orange', linewidth=2)

l.set_linestyle(':')  # set the linestyle of the second line after the fact


Color Styles

Matplotlib has a very flexible array of colors that are accepted for most Artists; see the colors tutorial for a list of specifications. Some Artists will take multiple colors. For example, in a scatter plot the edge of the markers can be a different color from the interior:

fig, ax = plt.subplots(figsize=(5, 2.7))

ax.scatter(data1, data2, s=50, facecolor='C0', edgecolor='k')

Linewidths, linestyles, and markersizes

Line widths are typically in typographic points (1 pt = 1/72 inch) and available for Artists that have stroked lines. Similarly, stroked lines can have a linestyle. See the linestyles example.

Marker size depends on the method being used. plot specifies markersize in points, and is generally the “diameter” or width of the marker. scatter specifies markersize as approximately proportional to the visual area of the marker. There is an array of markerstyles available as string codes (see markers), or users can define their own MarkerStyle (see Marker reference):

fig, ax = plt.subplots(figsize=(5, 2.7))

ax.plot(data1, 'o', label='data1')

ax.plot(data2, 'd', label='data2')

ax.plot(data3, 'v', label='data3')

ax.plot(data4, 's', label='data4')

ax.legend()  # show the labels given above
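Because plot and scatter measure marker size differently, it can help to see the two side by side. The sketch below assumes a headless "Agg" backend and a hypothetical diameter of 12 points; it relies on scatter's s argument being given in points squared, so passing the square of the plot markersize gives markers of roughly the same visual size.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 2.7))

diameter_pt = 12  # hypothetical marker diameter in points

# plot(): markersize is the marker "diameter" in points.
(line,) = ax.plot([0, 1, 2], [0, 0, 0], 'o', markersize=diameter_pt)

# scatter(): s is an area in points**2, so square the diameter to match.
points = ax.scatter([0, 1, 2], [1, 1, 1], s=diameter_pt**2)

ax.set_ylim(-1, 2)
```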


Labelling Matplotlib plots

Axes labels and text

set_xlabel, set_ylabel, and set_title are used to add text in the indicated locations (see Text in Matplotlib Plots for more discussion). Text can also be directly added to plots using text:

mu, sigma = 115, 15

x = mu + sigma * np.random.randn(10000)

fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')

# the histogram of the data

n, bins, patches = ax.hist(x, 50, density=True, facecolor='C0', alpha=0.75)

ax.set_xlabel('Length [cm]')


ax.set_title('Aardvark lengths\n (not really)')

ax.text(75, .025, r'$\mu=115,\ \sigma=15$')

ax.axis([55, 175, 0, 0.03])


All of the text functions return a matplotlib.text.Text instance. Just as with lines above, you can customize the properties by passing keyword arguments into the text functions:

t = ax.set_xlabel('my data', fontsize=14, color='red')

Using mathematical expressions in text

Matplotlib accepts TeX equation expressions in any text expression. For example, to write the expression σ_i = 15 in the title, you can write a TeX expression surrounded by dollar signs:

ax.set_title(r'$\sigma_i=15$')

where the r preceding the title string signifies that the string is a raw string, so backslashes are not treated as Python escapes. Matplotlib has a built-in TeX expression parser and layout engine, and ships its own math fonts – for details see Writing mathematical expressions. You can also use LaTeX directly to format your text and incorporate the output directly into your display figures or saved postscript.

Annotating your Matplotlib charts

We can also annotate points on a plot, often by connecting an arrow pointing to xy to a piece of text at xytext:

fig, ax = plt.subplots(figsize=(5, 2.7))

t = np.arange(0.0, 5.0, 0.01)

s = np.cos(2 * np.pi * t)

line, = ax.plot(t, s, lw=2)

ax.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(facecolor='black', shrink=0.05))

ax.set_ylim(-2, 2)


Legends

Often we want to identify lines or markers with an Axes.legend:

fig, ax = plt.subplots(figsize=(5, 2.7))

ax.plot(np.arange(len(data1)), data1, label='data1')

ax.plot(np.arange(len(data2)), data2, label='data2')

ax.plot(np.arange(len(data3)), data3, 'd', label='data3')

ax.legend()  # draw the legend from the labels above


Legends in Matplotlib are quite flexible in layout, placement, and what Artists they can represent. 

X Axis and Y Axis scales and ticks

Each Axes has two (or three) Axis objects representing the x- and y-axis. These control the scale of the Axis, the tick locators and the tick formatters. Additional Axes can be attached to display further Axis objects.


Scales

In addition to the linear scale, Matplotlib supplies non-linear scales, such as a log-scale. Since log-scales are used so much, there are also direct methods like loglog, semilogx, and semilogy. There are a number of scales (see Scales for other examples). Here we set the scale manually:

fig, axs = plt.subplots(1, 2, figsize=(5, 2.7), layout='constrained')

xdata = np.arange(len(data1))  # make an ordinal for this

data = 10**data1

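axs[0].set_xscale('linear')  # the default, set explicitly for comparison

axs[0].plot(xdata, data)

axs[1].set_yscale('log')  # set the y-axis scale manually

axs[1].plot(xdata, data)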

The scale sets the mapping from data values to spacing along the Axis. This happens in both directions, and gets combined into a transform, which is the way that Matplotlib maps from data coordinates to Axes, Figure, or screen coordinates. 
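As a rough illustration of how the scale feeds into the data transform, the sketch below (assuming a headless "Agg" backend and arbitrary axis limits) maps three y-values a decade apart through ax.transData on a log-scaled Axes. On a log scale, equal ratios in data space should land at equal pixel distances.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_yscale('log')
ax.set_xlim(0, 1)
ax.set_ylim(1, 100)

# transData maps data coordinates to display (pixel) coordinates.
pixels = ax.transData.transform([(0, 1), (0, 10), (0, 100)])

# Vertical pixel distance between successive decades.
steps = np.diff(pixels[:, 1])
print(steps)
```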

Tick locators and formatters

Each Axis has a tick locator and formatter that choose where along the Axis objects to put tick marks. A simple interface to this is set_xticks:

fig, axs = plt.subplots(2, 1, layout='constrained')

axs[0].plot(xdata, data1)

axs[0].set_title('Automatic ticks')

axs[1].plot(xdata, data1)

axs[1].set_xticks(np.arange(0, 100, 30), ['zero', '30', 'sixty', '90'])

axs[1].set_yticks([-1.5, 0, 1.5])  # note that we don't need to specify labels

axs[1].set_title('Manual ticks')

Different scales can have different locators and formatters; for instance, the log-scale above uses LogLocator and LogFormatter.
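Locators and Formatters can also be set explicitly through the matplotlib.ticker module. The sketch below is one possible illustration, assuming a headless "Agg" backend and made-up data: it places major ticks every 25 units with a custom label format on a linear axis, then inspects which locator and formatter a log-scaled axis picks up.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(6, 2.7))
x = np.arange(100)

# Linear axis: a tick every 25 units, labelled like "25 s".
ax_lin.plot(x, np.sqrt(x))
ax_lin.xaxis.set_major_locator(mticker.MultipleLocator(25))
ax_lin.xaxis.set_major_formatter(mticker.StrMethodFormatter('{x:.0f} s'))

# Log axis: switching the scale swaps in log-aware tick machinery.
ax_log.plot(x + 1, (x + 1) ** 2)
ax_log.set_yscale('log')
log_locator = ax_log.yaxis.get_major_locator()
log_formatter = ax_log.yaxis.get_major_formatter()
print(type(log_locator).__name__, type(log_formatter).__name__)
```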

Plotting dates and strings

Matplotlib can handle plotting arrays of dates and arrays of strings, as well as floating point numbers. These get special locators and formatters as appropriate. For dates:

fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')

dates = np.arange(np.datetime64('2021-11-15'), np.datetime64('2021-12-25'),
                  np.timedelta64(1, 'h'))

data = np.cumsum(np.random.randn(len(dates)))

ax.plot(dates, data)

cdf = mpl.dates.ConciseDateFormatter(ax.xaxis.get_major_locator())

ax.xaxis.set_major_formatter(cdf)  # apply the concise date formatter to the x-axis


For strings, we get categorical plotting.

fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')

categories = ['turnips', 'rutabaga', 'cucumber', 'pumpkins']

ax.bar(categories, np.random.rand(len(categories)))

One caveat about categorical plotting is that some methods of parsing text files return a list of strings, even if the strings all represent numbers or dates. If you pass 1000 strings, Matplotlib will think you meant 1000 categories and will add 1000 ticks to your plot!
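A simple way to avoid this caveat is to convert the strings to numbers before plotting. The sketch below uses a hypothetical list of numeric strings (standing in for values read from a text file) and a headless "Agg" backend; the left panel treats the strings as categories, while the right panel plots the converted floats on a numeric axis.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical values as they might come out of a text-file parser.
raw = ['10.5', '2.0', '7.25', '30']
positions = range(len(raw))

# Convert to floats so Matplotlib sees numbers rather than categories.
values = np.asarray(raw, dtype=float)

fig, (ax_cat, ax_num) = plt.subplots(1, 2, figsize=(6, 2.7))
ax_cat.plot(positions, raw, 'o')      # strings: one category per distinct string
ax_cat.set_title('strings (categorical)')
ax_num.plot(positions, values, 'o')   # floats: ordinary numeric axis
ax_num.set_title('floats (numeric)')
```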

Additional Axis objects

Plotting data of different magnitude in one chart may require an additional y-axis. Such an Axis can be created by using twinx to add a new Axes with an invisible x-axis and a y-axis positioned at the right (analogously for twiny). See Plots with different scales for another example.

Similarly, you can add a secondary_xaxis or secondary_yaxis having a different scale than the main Axis to represent the data in different scales or units. See Secondary Axis for further examples.

fig, (ax1, ax3) = plt.subplots(1, 2, figsize=(7, 2.7), layout='constrained')

l1, = ax1.plot(t, s)

ax2 = ax1.twinx()

l2, = ax2.plot(t, range(len(t)), 'C1')

ax2.legend([l1, l2], ['Cosine (left)', 'Straight (right)'])

ax3.plot(t, s)

ax3.set_xlabel('Angle [rad]')

ax4 = ax3.secondary_xaxis('top', functions=(np.rad2deg, np.deg2rad))

ax4.set_xlabel('Angle [°]')

Color mapped data

Often we want to have a third dimension in a plot represented by colors in a colormap. Matplotlib has a number of plot types that do this:

X, Y = np.meshgrid(np.linspace(-3, 3, 128), np.linspace(-3, 3, 128))

Z = (1 - X/2 + X**5 + Y**3) * np.exp(-X**2 - Y**2)

fig, axs = plt.subplots(2, 2, layout='constrained')

pc = axs[0, 0].pcolormesh(X, Y, Z, vmin=-1, vmax=1, cmap='RdBu_r')

fig.colorbar(pc, ax=axs[0, 0])

axs[0, 0].set_title('pcolormesh()')

co = axs[0, 1].contourf(X, Y, Z, levels=np.linspace(-1.25, 1.25, 11))

fig.colorbar(co, ax=axs[0, 1])

axs[0, 1].set_title('contourf()')

pc = axs[1, 0].imshow(Z**2 * 100, cmap='plasma',
                          norm=mpl.colors.LogNorm(vmin=0.01, vmax=100))

fig.colorbar(pc, ax=axs[1, 0], extend='both')

axs[1, 0].set_title('imshow() with LogNorm()')

pc = axs[1, 1].scatter(data1, data2, c=data3, cmap='RdBu_r')

fig.colorbar(pc, ax=axs[1, 1], extend='both')

axs[1, 1].set_title('scatter()')

Working with multiple Figures and Axes

You can open multiple Figures with multiple calls to fig = plt.figure() or fig2, ax = plt.subplots(). By keeping the object references you can add Artists to either Figure.

Multiple Axes can be added a number of ways, but the most basic is plt.subplots() as used above. One can achieve more complex layouts, with Axes objects spanning columns or rows, using subplot_mosaic.

fig, axd = plt.subplot_mosaic([['upleft', 'right'],
                               ['lowleft', 'right']], layout='constrained')



axd['upleft'].set_title('upleft')

axd['lowleft'].set_title('lowleft')

axd['right'].set_title('right')

Learn to work with Seaborn code – Basic Usage

Most of your interactions with seaborn will happen through a set of plotting functions. Later chapters in the tutorial will explore the specific features offered by each function. This chapter will introduce, at a high-level, the different kinds of functions that you will encounter.

The Seaborn package can be imported as follows:

import seaborn as sns

Similar functions for similar tasks

The seaborn namespace is flat; all of the functionality is accessible at the top level. But the code itself is hierarchically structured, with modules of functions that achieve similar visualization goals through different means. Most of the docs are structured around these modules: you’ll encounter names like “relational”, “distributional”, and “categorical”.

Histogram in Seaborn

For example, the distributions module defines functions that specialize in representing the distribution of datapoints. This includes familiar methods like the histogram:

penguins = sns.load_dataset("penguins")

sns.histplot(data=penguins, x="flipper_length_mm", hue="species", multiple="stack")

Kernel density estimation in Seaborn

Along with similar, but perhaps less familiar, options such as kernel density estimation:

sns.kdeplot(data=penguins, x="flipper_length_mm", hue="species", multiple="stack")

Functions within a module share a lot of underlying code and offer similar features that may not be present in other components of the library (such as multiple="stack" in the examples above). They are designed to facilitate switching between different visual representations as you explore a dataset, because different representations often have complementary strengths and weaknesses.

Figure-level vs. axes-level functions in Seaborn

In addition to the different modules, there is a cross-cutting classification of seaborn functions as “axes-level” or “figure-level”. The examples above are axes-level functions. They plot data onto a single matplotlib.pyplot.Axes object, which is the return value of the function.

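In contrast, figure-level functions interface with matplotlib through a seaborn object, usually a FacetGrid, that manages the figure. Each module has a single figure-level function, which offers a unitary interface to its various axes-level functions. The organization looks a bit like this:

  • relplot() (relational): scatterplot(), lineplot()
  • displot() (distributions): histplot(), kdeplot(), ecdfplot(), rugplot()
  • catplot() (categorical): stripplot(), swarmplot(), boxplot(), violinplot(), pointplot(), barplot()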


Displot in Seaborn

For example, displot() is the figure-level function for the distributions module. Its default behavior is to draw a histogram, using the same code as histplot() behind the scenes:

sns.displot(data=penguins, x="flipper_length_mm", hue="species", multiple="stack")

To draw a kernel density plot instead, using the same code as kdeplot(), select it using the kind parameter:

sns.displot(data=penguins, x="flipper_length_mm", hue="species", multiple="stack", kind="kde")

You’ll notice that the figure-level plots look mostly like their axes-level counterparts, but there are a few differences. Notably, the legend is placed outside the plot. They also have a slightly different shape (more on that shortly).

The most useful feature offered by the figure-level functions is that they can easily create figures with multiple subplots. For example, instead of stacking the three distributions for each species of penguins in the same axes, we can “facet” them by plotting each distribution across the columns of the figure:

sns.displot(data=penguins, x="flipper_length_mm", hue="species", col="species")

The figure-level functions wrap their axes-level counterparts and pass the kind-specific keyword arguments (such as the bin size for a histogram) down to the underlying function. That means they are no less flexible, but there is a downside: the kind-specific parameters don’t appear in the function signature or docstrings. Some of their features might be less discoverable, and you may need to look at two different pages of the documentation before understanding how to achieve a specific goal.

Axes-level functions make self-contained plots in Seaborn

The axes-level functions are written to act like drop-in replacements for matplotlib functions. While they add axis labels and legends automatically, they don’t modify anything beyond the axes that they are drawn into. That means they can be composed into arbitrarily-complex matplotlib figures with predictable results.

The axes-level functions call matplotlib.pyplot.gca() internally, which hooks into the matplotlib state-machine interface so that they draw their plots on the “currently-active” axes. But they additionally accept an ax= argument, which integrates with the object-oriented interface and lets you specify exactly where each plot should go:

f, axs = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw=dict(width_ratios=[4, 3]))

sns.scatterplot(data=penguins, x="flipper_length_mm", y="bill_length_mm", hue="species", ax=axs[0])

sns.histplot(data=penguins, x="species", hue="species", shrink=.8, alpha=.8, legend=False, ax=axs[1])


Figure-level functions own their figure in Seaborn

In contrast, figure-level functions cannot (easily) be composed with other plots. By design, they “own” their own figure, including its initialization, so there’s no notion of using a figure-level function to draw a plot onto an existing axes. This constraint allows the figure-level functions to implement features such as putting the legend outside of the plot.

Nevertheless, it is possible to go beyond what the figure-level functions offer by accessing the matplotlib axes on the object that they return and adding other elements to the plot that way:

tips = sns.load_dataset("tips")

g = sns.relplot(data=tips, x="total_bill", y="tip")

g.ax.axline(xy1=(10, 2), slope=.2, color="b", dashes=(5, 2))

Customizing plots from a figure-level function in Seaborn

The figure-level functions return a FacetGrid instance, which has a few methods for customizing attributes of the plot in a way that is “smart” about the subplot organization. For example, you can change the labels on the external axes using a single line of code:

g = sns.relplot(data=penguins, x="flipper_length_mm", y="bill_length_mm", col="sex")

g.set_axis_labels("Flipper length (mm)", "Bill length (mm)")

While convenient, this does add a bit of extra complexity, as you need to remember that this method is not part of the matplotlib API and exists only when using a figure-level function.
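To make the split between the two APIs concrete, here is a minimal sketch using a small synthetic DataFrame (hypothetical columns "x", "y", and "sex") and a headless "Agg" backend. It uses the FacetGrid method set_axis_labels() for the shared labels, then drops down to plain matplotlib calls on each subplot through the grid's axes array.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data (hypothetical columns).
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "x": rng.normal(size=60),
    "y": rng.normal(size=60),
    "sex": np.tile(["Male", "Female"], 30),
})

g = sns.relplot(data=df, x="x", y="y", col="sex")

# FacetGrid method: labels only the outer axes of the grid.
g.set_axis_labels("X value", "Y value")

# Plain matplotlib: g.axes is a 2D array of Axes, one per facet.
for ax in g.axes.flat:
    ax.axhline(0, color="gray", linewidth=1)

print(g.axes.shape)
```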

Specifying figure sizes in Seaborn

To increase or decrease the size of a matplotlib plot, you set the width and height of the entire figure, either in the global rcParams, while setting up the plot (e.g. with the figsize parameter of matplotlib.pyplot.subplots()), or by calling a method on the figure object (e.g. matplotlib.figure.Figure.set_size_inches()). When using an axes-level function in seaborn, the same rules apply: the size of the plot is determined by the size of the figure it is part of and the axes layout in that figure.

When using a figure-level function, there are several key differences. First, the functions themselves have parameters to control the figure size (although these are actually parameters of the underlying FacetGrid that manages the figure). Second, these parameters, height and aspect, parameterize the size slightly differently than the width, height parameterization in matplotlib (using the seaborn parameters, width = height * aspect). Most importantly, the parameters correspond to the size of each subplot, rather than the size of the overall figure.

To illustrate the difference between these approaches, here is the default output of matplotlib.pyplot.subplots() with one subplot:

f, ax = plt.subplots()

A figure with multiple columns will have the same overall size, but the axes will be squeezed horizontally to fit in the space:

f, ax = plt.subplots(1, 2, sharey=True)

Facetgrid in Seaborn

In contrast, a plot created by a figure-level function will be square. To demonstrate that, let’s set up an empty plot by using FacetGrid directly. This happens behind the scenes in functions like relplot(), displot(), or catplot():

g = sns.FacetGrid(penguins)

When additional columns are added, the figure itself will become wider, so that its subplots have the same size and shape:

g = sns.FacetGrid(penguins, col="sex")

And you can adjust the size and shape of each subplot without accounting for the total number of rows and columns in the figure:

g = sns.FacetGrid(penguins, col="sex", height=3.5, aspect=.75)

The upshot is that you can assign faceting variables without stopping to think about how you’ll need to adjust the total figure size. A downside is that, when you do want to change the figure size, you’ll need to remember that things work a bit differently than they do in matplotlib.
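The per-subplot sizing can be verified with a little arithmetic. In the sketch below (synthetic two-level "sex" column, headless "Agg" backend, and the g.figure attribute, which is assumed to be available in the installed seaborn version), each facet is 3.5 inches tall and 3.5 * 0.75 = 2.625 inches wide, so a grid with two columns should produce a figure of about 5.25 by 3.5 inches.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data with two levels in the faceting column.
df = pd.DataFrame({
    "value": np.arange(10.0),
    "sex": np.tile(["Male", "Female"], 5),
})

height, aspect = 3.5, 0.75
g = sns.FacetGrid(df, col="sex", height=height, aspect=aspect)

n_rows, n_cols = g.axes.shape
expected_size = (n_cols * height * aspect, n_rows * height)
actual_size = tuple(g.figure.get_size_inches())
print(expected_size, actual_size)
```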

Relative merits of figure-level functions in Seaborn

Here is a summary of the pros and cons that we have discussed above:

Advantages | Drawbacks
Easy faceting by data variables | Many parameters not in function signature
Legend outside of plot by default | Cannot be part of a larger matplotlib figure
Easy figure-level customization | Different API from matplotlib

On balance, the figure-level functions add some additional complexity that can make things more confusing for beginners, but their distinct features give them additional power. The tutorial documentation mostly uses the figure-level functions, because they produce slightly cleaner plots, and we generally recommend their use for most applications. The one situation where they are not a good choice is when you need to make a complex, standalone figure that composes multiple different plot kinds. At this point, it’s recommended to set up the figure using matplotlib directly and to fill in the individual components using axes-level functions.

Combining multiple views on the data – Sample exploratory data analysis

Two important plotting functions in seaborn don’t fit cleanly into the classification scheme discussed above. These functions, jointplot() and pairplot(), employ multiple kinds of plots from different modules to represent multiple aspects of a dataset in a single figure. Both plots are figure-level functions and create figures with multiple subplots by default. But they use different objects to manage the figure: JointGrid and PairGrid, respectively.


jointplot() plots the relationship or joint distribution of two variables while adding marginal axes that show the univariate distribution of each one separately:

sns.jointplot(data=penguins, x="flipper_length_mm", y="bill_length_mm", hue="species")


pairplot() is similar — it combines joint and marginal views — but rather than focusing on a single relationship, it visualizes every pairwise combination of variables simultaneously:

sns.pairplot(data=penguins, hue="species")


Behind the scenes, these functions are using axes-level functions that you have already met (scatterplot() and kdeplot()), and they also have a kind parameter that lets you quickly swap in a different representation:

sns.jointplot(data=penguins, x="flipper_length_mm", y="bill_length_mm", hue="species", kind="hist")



Decision Tree Algorithm in Machine Learning: Concepts, Techniques, and Python Scikit Learn Example


Machine learning is a subfield of artificial intelligence that involves the development of algorithms that can learn from data and make predictions or decisions based on patterns learned from the data. Decision trees are one of the most widely used and interpretable machine learning algorithms that can be used for both classification and regression tasks. They are particularly popular in fields such as finance, healthcare, marketing, and customer analytics due to their ability to provide understandable and transparent models.

In this article, we will provide a comprehensive overview of decision trees, covering their concepts, techniques, and practical implementation using Python. We will start by explaining the basic concepts of decision trees, including tree structure, node types, and decision rules. We will then delve into the techniques for constructing decision trees, such as entropy, information gain, and Gini impurity, as well as tree pruning methods for improving model performance. Next, we will discuss feature selection techniques in decision trees, including splitting rules, attribute selection measures, and handling missing values. Finally, we will explore methods for interpreting decision tree models, including model visualization, feature importance analysis, and model explanation.

Important decision tree concepts

Decision trees are tree-like structures that represent decision-making processes or decisions based on the input features. They consist of nodes, edges, and leaves, where nodes represent decision points, edges represent decisions or outcomes, and leaves represent the final prediction or decision. Each node in a decision tree corresponds to a feature or attribute, and the tree is constructed recursively by splitting the data based on the values of the features until a decision or prediction is reached.

[Figure: Elements of a Decision Tree]

There are several important concepts to understand in decision trees:

  1. Root Node: The topmost node in a decision tree, also known as the root node, represents the feature that provides the best split of the data based on a selected splitting criterion.
  2. Internal Nodes: Internal nodes in a decision tree represent decision points where the data is split into different branches based on the feature values. Internal nodes contain decision rules that determine the splitting criterion and the branching direction.
  3. Leaf Nodes: Leaf nodes in a decision tree represent the final decision or prediction. They do not have any outgoing edges and provide the output or prediction for the input data based on the majority class or mean/median value, depending on whether it’s a classification or regression problem.
  4. Decision Rules: Decision rules in a decision tree are determined based on the selected splitting criterion, which measures the impurity or randomness of the data. The decision rule at each node determines the feature value that is used to split the data into different branches.
  5. Impurity Measures: Impurity measures are used to determine the splitting criterion in decision trees. Common impurity measures include entropy and Gini impurity; information gain is the reduction in entropy achieved by a split. These measures quantify the randomness or impurity of the data at each node, and the split that most reduces the impurity of the resulting child nodes is selected.


Decision Tree Construction Techniques

The process of constructing a decision tree involves recursively splitting the data based on the values of the features until a stopping criterion is met. Several criteria can be used to choose each split, including entropy, information gain, and Gini impurity.


Entropy

Entropy is a measure of the randomness or impurity of the data at a node in a decision tree. It is defined as the negative of the sum, over all classes, of each class probability multiplied by its base-2 logarithm. The formula for entropy is given as:

Entropy = – Σ p(i) * log2(p(i))

where p(i) is the probability of class i in the data at a node. The goal of entropy-based decision tree construction is to minimize the entropy, or equivalently maximize the information gain, at each split, which leads to purer child nodes.
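The entropy formula can be sketched in a few lines of plain Python. This is an illustrative helper (the function name entropy and the toy label lists are assumptions, not part of scikit-learn) that estimates p(i) from class counts at a node.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    total = len(labels)
    probabilities = [count / total for count in Counter(labels).values()]
    return sum(-p * log2(p) for p in probabilities)


balanced = entropy(["a", "a", "b", "b"])  # two equally likely classes
pure = entropy(["a", "a", "a", "a"])      # a single class
skewed = entropy(["a", "a", "a", "b"])    # 3:1 class ratio
print(balanced, pure, round(skewed, 3))
```

A perfectly balanced two-class node has the maximum entropy of 1 bit, while a pure node has entropy 0.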

Information Gain

Information gain is another commonly used criterion for decision tree construction. It measures the reduction in entropy or increase in information at a node after a particular split. Information gain is calculated as the difference between the entropy of the parent node and the weighted average of the entropies of the child nodes after the split. The formula for information gain is given as:

Information Gain = Entropy(parent) – Σ (|Sv|/|S|) * Entropy(Sv)

where Sv is the subset of data after the split based on a particular feature value, and |S| and |Sv| are the total number of samples in the parent node and the subset Sv, respectively. The decision rule that leads to the highest information gain is selected as the splitting criterion.
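As a worked illustration of the formula above, the sketch below computes information gain for two candidate splits of a toy parent node. The helper names (entropy, information_gain) and the label lists are hypothetical, chosen only for this example.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    total = len(labels)
    probabilities = [count / total for count in Counter(labels).values()]
    return sum(-p * log2(p) for p in probabilities)


def information_gain(parent, children):
    """Parent entropy minus the size-weighted average entropy of the children."""
    total = len(parent)
    weighted = sum(len(child) / total * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = ["yes"] * 4 + ["no"] * 4

# A split that separates the classes perfectly.
perfect_gain = information_gain(parent, [["yes"] * 4, ["no"] * 4])

# A split that leaves each child as mixed as the parent.
useless_gain = information_gain(
    parent, [["yes", "yes", "no", "no"], ["yes", "yes", "no", "no"]]
)
print(perfect_gain, useless_gain)
```

The perfect split recovers the full 1 bit of parent entropy, whereas the uninformative split gains nothing, so a tree builder using this criterion would prefer the first.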

Gini Impurity

Gini impurity is another impurity measure used in decision tree construction. It measures the probability that a randomly chosen sample at a node would be misclassified if it were labeled at random according to the class distribution at that node. The formula for Gini impurity is given as:

Gini Impurity = 1 – Σ p(i)^2

where p(i) is the probability of class i in the data at a node. Similar to entropy and information gain, the goal of Gini impurity-based decision tree construction is to minimize the Gini impurity or maximize the Gini gain at each split.
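The Gini impurity formula can be sketched in the same way. The function name gini_impurity and the example labels below are illustrative assumptions.

```python
from collections import Counter

def gini_impurity(labels):
    """Gini impurity: 1 minus the sum of squared class probabilities."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

gini_pure = gini_impurity([0, 0, 0, 0])   # a single class
gini_even = gini_impurity([0, 0, 1, 1])   # two equally likely classes
print(gini_pure, gini_even)
```

For two classes the impurity ranges from 0 (pure node) to 0.5 (evenly mixed node).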


Decision Trees in Python Scikit-Learn (sklearn)

Python provides several libraries for implementing decision trees, such as scikit-learn, XGBoost, and LightGBM. Here, we will illustrate an example of decision tree classifier implementation using scikit-learn, one of the most popular machine learning libraries in Python.

Download the dataset here: Iris dataset uci | Kaggle

# Import the required libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Load the dataset
data = pd.read_csv('iris.csv')  # Load the iris dataset

# Split the dataset into features and labels
X = data.iloc[:, :-1]  # Features
y = data.iloc[:, -1]  # Labels

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the decision tree classifier
clf = DecisionTreeClassifier()

# Train the decision tree classifier
clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = clf.predict(X_test)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy * 100))

In this example, we load the popular Iris dataset, split it into features (X) and labels (y), and then split it into training and testing sets using the train_test_split function from scikit-learn. We then initialize a decision tree classifier using the DecisionTreeClassifier class from scikit-learn, fit the classifier to the training data using the fit method, and make predictions on the test data using the predict method. Finally, we calculate the accuracy of the decision tree classifier using the accuracy_score function from scikit-learn.


Overfitting in Decision Trees and how to prevent overfitting

Overfitting is a common problem in decision trees where the model becomes too complex and captures noise instead of the underlying patterns in the data. As a result, the tree performs well on the training data but poorly on new, unseen data.

To prevent overfitting in decision trees, we can use the following techniques:

Use more data to prevent overfitting

Overfitting can occur when a model is trained on a limited amount of data, causing it to capture noise rather than the underlying patterns. Collecting more data can help the model generalize better, reducing the likelihood of overfitting.

  • Collect more data from various sources
  • Use data augmentation techniques to create synthetic data

Set a minimum number of samples for each leaf node

A leaf node is a terminal node in a decision tree that contains the final classification decision. Setting a minimum number of samples for each leaf node can help prevent the model from splitting the data too finely, which can lead to overfitting.

from sklearn.tree import DecisionTreeClassifier
dtc = DecisionTreeClassifier(min_samples_leaf=5)

Prune and visualize the decision tree

Decision trees are prone to overfitting, which means they can become too complex and fit the training data too closely, resulting in poor generalization performance on unseen data. Pruning is a technique used to prevent overfitting by removing unnecessary branches or nodes from a decision tree.


Pre-pruning

Pre-pruning is a pruning technique that involves stopping the tree construction process early, before the tree has grown to its full depth. This prevents the tree from becoming too deep or too complex, and helps in creating a simpler and more interpretable decision tree. Pre-pruning can be done by setting a maximum depth for the tree, a minimum number of samples per leaf, or a maximum number of leaf nodes.

from sklearn.tree import DecisionTreeClassifier

# Set the maximum depth for the tree
max_depth = 5

# Set the minimum number of samples per leaf
min_samples_leaf = 10

# Create a decision tree classifier with pre-pruning
clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf)

# Fit the model on the training data
clf.fit(X_train, y_train)

# Evaluate the model on the test data
y_pred = clf.predict(X_test)


Post-pruning

Post-pruning is a pruning technique that involves constructing the decision tree to its maximum depth or allowing it to overfit the training data, and then pruning back the unnecessary branches or nodes. This is done by evaluating the performance of the tree on a validation set or using a pruning criterion such as cost-complexity pruning. Cost-complexity pruning adds a penalty proportional to the number of leaf nodes to the tree's error, controlled by a parameter alpha (ccp_alpha in scikit-learn), and prunes back the branches whose contribution to performance does not justify the added complexity.

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text

# Create a decision tree classifier without pruning
clf = DecisionTreeClassifier()

# Fit the model on the training data
clf.fit(X_train, y_train)

# Evaluate the model on the validation data. This is the baseline score.
# X_val and y_val are assumed to be a validation split held out from the training data.
score = clf.score(X_val, y_val)

# Print the decision tree before pruning
print(export_text(clf))

# Prune the decision tree using cost-complexity pruning
ccp_alphas = clf.cost_complexity_pruning_path(X_train, y_train).ccp_alphas
for ccp_alpha in ccp_alphas:
    pruned_clf = DecisionTreeClassifier(ccp_alpha=ccp_alpha).fit(X_train, y_train)
    pruned_score = pruned_clf.score(X_val, y_val)
    if pruned_score > score:
        score = pruned_score
        clf = pruned_clf

# Print the decision tree after pruning
print(export_text(clf))
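The post-pruning snippet relies on a validation split that is not defined in it. As a self-contained sketch of the same idea, the example below assumes the Iris data bundled with scikit-learn and a 70/30 train/validation split; the variable names and the tie-breaking rule (prefer the smaller tree when validation scores are equal) are our own choices, not a prescribed procedure.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumed setup: Iris data with a held-out validation set
X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

# Grow the full (unpruned) tree and compute the candidate alphas from it
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

best_tree, best_score = full_tree, full_tree.score(X_val, y_val)
for alpha in path.ccp_alphas:
    alpha = max(alpha, 0.0)  # guard against tiny negative values from rounding
    candidate = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
    candidate.fit(X_train, y_train)
    score = candidate.score(X_val, y_val)
    # alphas increase along the path, so ">=" keeps the smaller tree on ties
    if score >= best_score:
        best_tree, best_score = candidate, score

print("Leaves before pruning:", full_tree.get_n_leaves())
print("Leaves after pruning:", best_tree.get_n_leaves())
print("Validation accuracy: {:.2f}".format(best_score))
```

Because the pruned tree is chosen using the validation set, its final performance should be reported on a separate test set.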

Use cross-validation to evaluate model performance

Cross-validation is a technique for evaluating the performance of a model by training and testing it on different subsets of the data. This can help prevent overfitting by testing the model’s ability to generalize to new data.

In this example, we use cross_val_score from scikit-learn.

from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
dtc = DecisionTreeClassifier()
scores = cross_val_score(dtc, X, y, cv=10)
print("Cross-validation scores: {}".format(scores))

Limit the depth of the tree

Limiting the depth of the tree can prevent the model from becoming too complex and overfitting to the training data. This can be done by setting a maximum depth or a minimum number of samples required for a node to be split.

from sklearn.tree import DecisionTreeClassifier
dtc = DecisionTreeClassifier(max_depth=5)

Use ensemble methods like random forests or boosting

Ensemble methods combine multiple decision trees to improve the model’s accuracy and prevent overfitting. Random forests create a collection of decision trees by randomly sampling the data and features for each tree, while boosting iteratively trains decision trees on the residual errors of the previous trees.

Here is an example of using the GradientBoostingClassifier from scikit-learn.

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_classes=2, random_state=42)

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit gradient boosting classifier to training data
gb = GradientBoostingClassifier(n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42)
gb.fit(X_train, y_train)

# Evaluate performance on test data
print("Accuracy: {:.2f}".format(gb.score(X_test, y_test)))
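The text above also mentions random forests. A comparable sketch with RandomForestClassifier is shown below; it reuses the same synthetic dataset settings as the boosting example, and the hyperparameter values are illustrative assumptions rather than tuned choices.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Generate the same kind of synthetic dataset as in the boosting example
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Each tree is trained on a bootstrap sample of the rows and considers
# a random subset of the features (here sqrt(n_features)) at every split
rf = RandomForestClassifier(n_estimators=100, max_features="sqrt", random_state=42)
rf.fit(X_train, y_train)

rf_accuracy = rf.score(X_test, y_test)
print("Accuracy: {:.2f}".format(rf_accuracy))
```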

Feature selection and engineering to reduce noise in the data

Feature selection involves selecting the most relevant features for the model, while feature engineering involves creating new features or transforming existing ones to better capture the underlying patterns in the data. This can help reduce noise in the data and prevent the model from overfitting to irrelevant or noisy features.

from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import chi2

# chi2 requires non-negative feature values, and k cannot exceed the number of features in X
X_new = SelectKBest(chi2, k=2).fit_transform(X, y)

Feature Selection Techniques in Decision Trees

Feature selection is an important step in machine learning to identify the most relevant features or attributes that contribute the most to the prediction or decision-making process. In decision trees, feature selection is typically done during the tree construction process when determining the splitting criterion. There are several techniques for feature selection in decision trees:

Feature Importance

Decision trees can also provide a measure of feature importance, which indicates the relative importance of each feature in the decision-making process. Feature importance is calculated based on the number of times a feature is used for splitting across all nodes in the tree and the improvement in the impurity measure (such as entropy or Gini impurity) achieved by each split. Features with higher importance values are considered more relevant and contribute more to the decision-making process.
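In scikit-learn, a fitted tree exposes these values through the feature_importances_ attribute, which holds the normalized total impurity reduction contributed by each feature. The sketch below assumes the bundled Iris data; the variable names are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Pair each feature name with its importance; the importances sum to 1
importances = dict(zip(iris.feature_names, tree.feature_importances_))
for name, value in sorted(importances.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {value:.3f}")
```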

Recursive Feature Elimination

Recursive feature elimination is a technique that recursively removes less important features from the decision tree based on their importance values. The decision tree is repeatedly trained with the remaining features, and the feature with the lowest importance value is removed at each iteration. This process is repeated until a desired number of features or a desired level of feature importance is achieved.
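As a minimal sketch of this procedure, scikit-learn provides the RFE class, which can wrap a decision tree and drop one feature per iteration. The Iris data and the target of two remaining features are assumptions chosen for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.feature_selection import RFE
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

# Remove the least important feature one at a time (step=1)
# until only two features remain
selector = RFE(DecisionTreeClassifier(random_state=0),
               n_features_to_select=2, step=1)
selector.fit(iris.data, iris.target)

selected = [name for name, keep in zip(iris.feature_names, selector.support_) if keep]
print("Selected features:", selected)
print("Ranking (1 = kept):", selector.ranking_)
```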



  1. Quinlan, J. R. (1986). Induction of decision trees. Machine learning, 1(1), 81-106. Link:
  2. Hastie, T., Tibshirani, R., & Friedman, J. (2009). The elements of statistical learning: data mining, inference, and prediction. Springer Science & Business Media. Link:
  3. Bishop, C. M. (2006). Pattern recognition and machine learning. Springer. Link:
  4. Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion, B., Grisel, O., … & Vanderplas, J. (2011). Scikit-learn: Machine learning in Python. Journal of machine learning research, 12(Oct), 2825-2830. Link:
  5. Kohavi, R., & Quinlan, J. R. (2002). Data mining tasks and methods: Classification: decision-tree discovery. Handbook of data mining and knowledge discovery, 267-276. Link:
  6. W. Loh, (2014). Fifty Years of Classification and Regression Trees 1. Link:

Frequently asked questions about decision trees in machine learning

  1. What is a decision tree in machine learning?

    A decision tree is a graphical representation of a decision-making process or decision rules, where each internal node represents a decision based on a feature or attribute, and each leaf node represents an outcome or decision class.

  2. What are the advantages of using decision trees?

    Decision trees are easy to understand and interpret, can handle both categorical and numerical data, require minimal data preparation, can handle missing values, and are capable of handling both classification and regression tasks.

  3. What are the common splitting criteria used in decision tree algorithms?

    Some common splitting criteria used in decision tree algorithms include Gini impurity, entropy, and information gain, which are used to determine the best attribute for splitting the data at each node.

  4. How can decision trees be used for feature selection?

    Decision trees can be used for feature selection by analyzing the feature importance or feature ranking obtained from the decision tree, which can help identify the most important features for making accurate predictions.

  5. What are the methods to avoid overfitting in decision trees?

    Some methods to avoid overfitting in decision trees include pruning techniques such as pre-pruning (e.g., limiting the depth of the tree) and post-pruning (e.g., pruning the tree after it is fully grown and then removing less important nodes), and using ensemble methods such as random forests and boosting.

  6. What are the limitations of decision trees?

    Some limitations of decision trees include their susceptibility to overfitting, sensitivity to small changes in the data, lack of robustness to noise and outliers, and difficulty in handling continuous or large-scale datasets.

  7. What are the common applications of decision trees in real-world problems?

    Decision trees are commonly used in various real-world problems, including classification tasks such as spam detection, medical diagnosis, credit risk assessment, and customer churn prediction, as well as regression tasks such as housing price prediction and demand forecasting.

  8. Can decision trees handle missing values in the data?

    Yes, decision trees can handle missing values in the data by using techniques such as surrogate splitting, where an alternative splitting rule is used when the value of a certain attribute is missing for a data point.

  9. Can decision trees be used for multi-class classification problems?

    Yes. Decision trees handle multi-class classification problems natively: impurity measures such as entropy and Gini impurity are defined over any number of classes, and each leaf predicts the majority class of its samples, so no one-vs-rest or one-vs-one decomposition is needed.

  10. How can I implement decision trees in Python?

    Decision trees can be implemented in Python using scikit-learn, which provides built-in classes such as DecisionTreeClassifier for training and evaluating decision tree models. Libraries such as XGBoost and LightGBM provide tree-based ensemble models.

  11. Is decision tree a supervised or unsupervised algorithm?

    A decision tree is a supervised learning algorithm that is used for classification and regression modeling.

  12. What is pruning in decision trees?

    Pruning is a technique used in decision tree algorithms to reduce the size of the tree by removing nodes or branches that do not contribute significantly to the accuracy of the model. This helps to avoid overfitting and improve the generalization performance of the model.

  13. What are the benefits of pruning?

    Pruning helps to simplify and interpret the decision tree model by reducing its size and complexity. It also improves the generalization performance of the model by reducing overfitting and increasing accuracy on new, unseen data.

  14. What are the different types of pruning for decision trees?

    There are two main types of pruning: pre-pruning and post-pruning. Pre-pruning involves stopping the tree construction process before it reaches its maximum depth or minimum number of samples per leaf, while post-pruning involves constructing the decision tree to its maximum depth and then pruning back unnecessary branches or nodes.

  15. How is pruning performed in decision trees?

    Pruning can be performed by setting a maximum depth for the tree, a minimum number of samples per leaf, or a maximum number of leaf nodes for pre-pruning. For post-pruning, the model is trained on the training data, evaluated on a validation set, and then unnecessary branches or nodes are pruned based on a pruning criterion such as cost-complexity pruning.

  16. When should decision trees be pruned?

    Pruning should be used when the decision tree model is too complex or overfits the training data. It should also be used when the size of the decision tree becomes impractical for interpretation or implementation.

  17. Are there any drawbacks to pruning?

    One potential drawback of pruning is that it can result in a loss of information or accuracy if too many nodes or branches are pruned. Additionally, pruning can be computationally expensive, especially for large datasets or complex decision trees.


SKLEARN LOGISTIC REGRESSION multiclass (more than 2) classification with Python scikit-learn


Logistic Regression is a commonly used machine learning algorithm for binary classification problems, where the goal is to predict one of two possible outcomes. However, in some cases, the target variable has more than two classes. In such cases, a multiclass classification problem is encountered. In this article, we will see how to create a logistic regression model using the scikit-learn library for multiclass classification problems.

Multinomial classification

Multinomial logistic regression is used when the dependent variable in question is nominal (equivalently categorical, meaning that it falls into any one of a set of categories that cannot be ordered in any meaningful way) and for which there are more than two categories. Some examples would be:

  • Which major will a college student choose, given their grades, stated likes and dislikes, etc.? 
  • Which blood type does a person have, given the results of various diagnostic tests? 
  • In a hands-free mobile phone dialing application, which person’s name was spoken, given various properties of the speech signal? 
  • Which candidate will a person vote for, given particular demographic characteristics? 
  • Which country will a firm locate an office in, given the characteristics of the firm and of the various candidate countries? 

These are all statistical classification problems. They all have in common a dependent variable to be predicted that comes from one of a limited set of items that cannot be meaningfully ordered, as well as a set of independent variables (also known as features, explanators, etc.), which are used to predict the dependent variable. Multinomial logistic regression is a particular solution to classification problems that use a linear combination of the observed features and some problem-specific parameters to estimate the probability of each particular value of the dependent variable. The best values of the parameters for a given problem are usually determined from some training data (e.g. some people for whom both the diagnostic test results and blood types are known, or some examples of known words being spoken).

Common Approaches

  • One-vs-Rest (OvR)
  • Softmax Regression (Multinomial Logistic Regression)
  • One vs One(OvO)

Multiclass classification problems are usually tackled in three ways: One-vs-Rest (OvR), One-vs-One (OvO), and the softmax function. In the OvR approach (also called One-vs-All or OvA), a separate binary classifier is trained for each class, where one class is considered positive and all other classes are considered negative. In the OvO approach, a separate binary classifier is trained for each pair of classes. For example, if there are k classes, then k(k-1)/2 classifiers will be trained in the OvO approach.
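The k(k-1)/2 count for OvO can be verified by enumerating the class pairs. This is a small sketch with a made-up list of four class names.

```python
from itertools import combinations

classes = ["red", "blue", "green", "yellow"]  # hypothetical class labels
k = len(classes)

# OvO trains one binary classifier per unordered pair of classes
pairs = list(combinations(classes, 2))
print(len(pairs), "pairs for", k, "classes")
for first, second in pairs:
    print(f"{first} vs {second}")
```

By comparison, OvR would need only k classifiers for the same problem, one per class.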

In this article, we will be using the OvR and softmax approach to create a logistic regression model for multiclass classification.

One-vs-Rest (OvR)

One-vs-rest (OvR for short, also referred to as One-vs-All or OvA) is a heuristic method for using binary classification algorithms for multi-class classification.

It involves splitting the multi-class dataset into multiple binary classification problems. A binary classifier is then trained on each binary classification problem and predictions are made using the model that is the most confident.

For example, consider a multi-class classification problem with examples for each of the classes ‘red’, ‘blue’, and ‘green’. This could be divided into three binary classification datasets as follows:

  • Binary Classification Problem 1: red vs [blue, green]
  • Binary Classification Problem 2: blue vs [red, green]
  • Binary Classification Problem 3: green vs [red, blue]

A possible downside of this approach is that it requires one model to be created for each class. For example, three classes require three models. This could be an issue for large datasets (e.g. millions of rows), slow models (e.g. neural networks), or very large numbers of classes (e.g. hundreds of classes).

This approach requires that each model predicts a class membership probability or a probability-like score. The argmax of these scores (class index with the largest score) is then used to predict a class.

In scikit-learn, inherently binary estimators such as LinearSVC use the OvR strategy for multi-class classification, and the same strategy can be applied to any binary classifier with the OneVsRestClassifier wrapper.

Multi class logistic regression using one vs rest (OVR) strategy

The strategy for handling multi-class classification can be set via the “multi_class” argument and can be set to “ovr” for the one-vs-rest strategy when using sklearn’s LogisticRegression class from linear_model.

To start, we need to import the required libraries:

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

Next, we will load the Iris dataset using the load_iris function from sklearn.datasets. It is a commonly used dataset for multiclass classification problems:

iris = load_iris()
X = iris.data
y = iris.target

The load_iris dataset contains information about the sepal length, sepal width, petal length, and petal width of 150 iris flowers. The target variable is the species of the iris flower, which has three classes – 0, 1, and 2.

Next, we will split the data into training and testing sets. 80%-20% split:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

Training the multiclass logistic regression model

Now, we can create a logistic regression model and train it on the training data:

model = LogisticRegression(solver='lbfgs', multi_class='ovr')
model.fit(X_train, y_train)

The multi_class parameter is set to ‘ovr’ to indicate that we are using the OvR approach for multiclass classification. The solver parameter is set to ‘lbfgs’, which is a suitable solver for small datasets like the Iris dataset.

Next, we can evaluate the performance of the model on the test data:

y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print("Accuracy:", accuracy)

The predict method is used to make predictions on the test data, and the accuracy of the predictions is calculated by comparing the predicted values with the actual values.

Finally, we can use the trained model to make predictions on new data:

new_data = np.array([[5.1, 3.5, 1.4, 0.2]])
y_pred = model.predict(new_data)
print("Prediction:", y_pred)

In this example, we have taken a single new data point with sepal length 5.1, sepal width 3.5, petal length 1.4, and petal width 0.2. The model will return the predicted class for this data point.
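To our understanding, recent scikit-learn releases deprecate the multi_class argument of LogisticRegression and recommend wrapping the estimator in OneVsRestClassifier to get the OvR strategy. The following sketch shows that alternative under the same data split as above; the max_iter value is an assumption added to give the solver room to converge.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# One binary logistic regression is trained per class (class vs. the rest)
ovr_model = OneVsRestClassifier(LogisticRegression(solver="lbfgs", max_iter=1000))
ovr_model.fit(X_train, y_train)

print("Number of binary models:", len(ovr_model.estimators_))
ovr_accuracy = ovr_model.score(X_test, y_test)
print("Accuracy:", ovr_accuracy)
```

The wrapper stores the fitted binary models in estimators_, one per class, and predicts the class whose model gives the highest score.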


Softmax Regression (Multinomial Logistic Regression)

The inputs to the multinomial logistic regression are the features we have in the dataset. For example, if we are going to predict the Iris flower species, the features will be the sepal length and width and the petal length and width. These features are treated as the inputs for the multinomial logistic regression.

The key point to remember here is that the feature values must always be numerical. If the features are not numerical, we need to convert them into numerical values using the proper categorical data analysis techniques.

Linear Model

The linear model equation is the same as the linear equation in the linear regression model. Here X is the set of inputs, a matrix that contains all the (numerical) feature values, X = [x1, x2, x3], and W is another matrix that holds one coefficient per input, W = [w1, w2, w3].

In this example, the linear model output is built from the products w1*x1, w2*x2, and w3*x3. In multinomial logistic regression there is one such set of weights per class, so the linear model produces one score per class.

Softmax Function 

The softmax function is a mathematical function that takes a vector of real numbers as input and outputs a probability distribution over the classes. It is often used in machine learning for multiclass classification problems, including neural networks and logistic regression models.

The softmax function is defined as:

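softmax(z)_i = exp(z_i) / Σ_j exp(z_j)

where z is the vector of linear model outputs (one score per class), z_i is the score of class i, and the sum in the denominator runs over all classes j.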

The softmax function transforms the input vector into a probability distribution over the classes, where each class is assigned a probability between 0 and 1, and the sum of the probabilities is 1. The class with the highest probability is then selected as the predicted class.
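As a minimal sketch, the softmax function can be written in a few lines of NumPy. The function name, the example scores, and the max-subtraction trick for numerical stability are our own additions for illustration.

```python
import numpy as np

def softmax(z):
    """Turn a vector of real-valued scores into a probability distribution."""
    z = np.asarray(z, dtype=float)
    shifted = z - np.max(z)      # subtracting the max avoids overflow in exp
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum()

scores = [2.0, 1.0, 0.1]         # hypothetical linear model outputs for 3 classes
probs = softmax(scores)
print("Probabilities:", probs)
print("Sum:", probs.sum())
print("Predicted class index:", probs.argmax())
```

Subtracting the maximum score does not change the result, because the same factor cancels in the numerator and the denominator.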

The softmax function is a generalization of the logistic (sigmoid) function used in binary classification. The sigmoid function outputs a single value between 0 and 1, which is interpreted as the probability of the input belonging to the positive class, whereas softmax outputs one probability for every class. With only two classes, softmax reduces to the sigmoid function.

Cross Entropy

Cross-entropy is the last stage of multinomial logistic regression. The cross-entropy function measures the distance between the probabilities calculated by the softmax function and the target one-hot-encoding matrix.

Because the one-hot-encoding matrix is 1 for the right target class and 0 everywhere else, the distance depends only on the probability assigned to the right class: it is small when that probability is close to 1 and grows larger as the probability falls toward 0.
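The behaviour described above can be sketched with NumPy. The function name, the example probability vectors, and the clipping constant used to avoid log(0) are illustrative assumptions.

```python
import numpy as np

def cross_entropy(probs, one_hot):
    """Cross-entropy between predicted probabilities and a one-hot target."""
    probs = np.clip(np.asarray(probs, dtype=float), 1e-12, 1.0)  # avoid log(0)
    return float(-np.sum(np.asarray(one_hot) * np.log(probs)))

target = [1, 0, 0]                                  # the true class is class 0
good = cross_entropy([0.9, 0.05, 0.05], target)     # confident and correct
bad = cross_entropy([0.1, 0.6, 0.3], target)        # most mass on a wrong class
print("Loss for a good prediction:", good)
print("Loss for a bad prediction:", bad)
```

Only the probability assigned to the true class contributes to the loss, which is why the confident and correct prediction scores lower.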

Multi class logistic regression using sklearn multinomial parameter

Multiclass logistic regression using softmax function (multinomial)

In the previous example, we created a logistic regression model for multiclass classification using the One-vs-All approach. In the softmax approach, the output of the logistic regression model is a vector of probabilities for each class. The class with the highest probability is then selected as the predicted class.

To use the softmax approach with logistic regression in scikit-learn, we need to set the multi_class parameter to ‘multinomial’ and the solver parameter to a solver that supports the multinomial loss function, such as ‘lbfgs’, ‘newton-cg’, or ‘sag’. Here’s an example of how to create a logistic regression model with multi_class set to ‘multinomial’:

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = LogisticRegression(solver='lbfgs', multi_class='multinomial')
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print("Accuracy:", accuracy)

new_data = np.array([[5.1, 3.5, 1.4, 0.2]])
y_pred = model.predict(new_data)
print("Prediction:", y_pred)

In this example, we have set the multi_class parameter to ‘multinomial’ and the solver parameter to ‘lbfgs’. The lbfgs solver is suitable for small datasets like the load_iris dataset. We then train the logistic regression model on the training data and evaluate its performance on the test data.

We can also use the predict_proba method to get the probability estimates for each class for a given input. Here’s an example:

probabilities = model.predict_proba(new_data)
print("Probabilities:", probabilities)

In this example, we have used the predict_proba method to get the probability estimates for each class for the new data point. The output is a vector of probabilities for each class.

It’s important to note that the logistic regression model is a linear model and may not perform well on complex non-linear datasets. In such cases, other algorithms like decision trees, random forests, and support vector machines may perform better.


In conclusion, we have seen how to create a logistic regression model using the scikit-learn library for multiclass classification problems using the OvA and softmax approach. The softmax approach can be more accurate than the One-vs-All approach but can also be more computationally expensive. We have used the load_iris dataset for demonstration purposes but the same steps can be applied to any multiclass classification problem. It’s important to choose the right algorithm based on the characteristics of the dataset and the problem requirements.

  1. Can logistic regression be used for multiclass classification?

    Yes. Logistic regression is a binary classification model at its core, but it supports multi-class classification problems either by splitting the problem into several binary problems (One-vs-Rest or One-vs-One) or by using the softmax function (multinomial logistic regression).

  2. Can you use logistic regression for a classification problem with three classes?

    Yes, we can apply logistic regression to a 3-class classification problem, using either the One-vs-Rest method or the multinomial (softmax) approach.

  3. When do I use predict_proba() instead of predict()?

    The predict() method is used to predict the actual class, while the predict_proba() method can be used to infer the class probabilities (i.e. the probability that a particular data point falls into each of the underlying classes). It is usually sufficient to use the predict() method to obtain the class labels directly. However, if you wish to further fine-tune your classification model, e.g. with threshold tuning, then you would need to use predict_proba().

  4. What is softmax function?

    The softmax function is a function that turns a vector of K real values into a vector of K real values that sum to 1. The input values can be positive, negative, zero, or greater than one, but the softmax transforms them into values between 0 and 1, so that they can be interpreted as probabilities.

  5. Why and when is Softmax used in logistic regression?

    The softmax function is used in classification algorithms where the output needs to be a probability distribution over the classes. Examples include neural networks and multinomial logistic regression (softmax regression).

  6. Why use softmax for classification?

    Softmax classifiers give you probabilities for each class label. It's much easier for us as humans to interpret probabilities to infer the class labels.


sklearn Linear Regression in Python with sci-kit learn and easy examples


Linear regression is a statistical method used for analyzing the relationship between a dependent variable and one or more independent variables. It is widely used in various fields, such as finance, economics, and engineering, to model the relationship between variables and make predictions. In this article, we will learn how to create a linear regression model using the scikit-learn library in Python.

Scikit-learn (also known as sklearn) is a popular Python library for machine learning that provides simple and efficient tools for data mining and data analysis. It provides a wide range of algorithms and models, including linear regression. In this article, we will use the sklearn library to create a linear regression model to predict the relationship between two variables.

Before we dive into the code, let’s first understand the basic concepts of linear regression.

Understanding Linear Regression

Linear regression is a method that models the relationship between a dependent variable (also known as the response variable or target variable) and one or more independent variables (also known as predictor variables or features). The goal of linear regression is to find the line of best fit that best predicts the dependent variable based on the independent variables.

In a simple linear regression, the relationship between the dependent variable and the independent variable is represented by the equation:

y = b0 + b1x

where y is the dependent variable, x is the independent variable, b0 is the intercept, and b1 is the slope.

The intercept b0 is the value of y when x is equal to zero, and the slope b1 represents the change in y for every unit change in x.

In multiple linear regression, the relationship between the dependent variable and multiple independent variables is represented by the equation:

y = b0 + b1x1 + b2x2 + ... + bnxn

where y is the dependent variable, x1, x2, …, xn are the independent variables, b0 is the intercept, and b1, b2, …, bn are the slopes.
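
As a rough illustration of the two equations above, the following sketch estimates b0 and b1 with the standard closed-form least-squares formulas, and then estimates the coefficients of a two-variable model with NumPy's least-squares solver. All data values are hypothetical and were generated without noise, so the fitted parameters match the ones used to create the data.

```python
import numpy as np

# Hypothetical data generated exactly from y = 1 + 2x
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = 1.0 + 2.0 * x

# Simple linear regression: closed-form least-squares estimates
b1 = np.sum((x - x.mean()) * (y - y.mean())) / np.sum((x - x.mean()) ** 2)
b0 = y.mean() - b1 * x.mean()
print(b0, b1)

# Multiple linear regression: y2 = 0.5 + 1.5*x1 - 2.0*x2 (hypothetical)
x1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
x2 = np.array([2.0, 1.0, 4.0, 3.0, 6.0])
y2 = 0.5 + 1.5 * x1 - 2.0 * x2

# A column of ones lets the solver estimate the intercept b0
A = np.column_stack([np.ones_like(x1), x1, x2])
coeffs, *_ = np.linalg.lstsq(A, y2, rcond=None)
print(coeffs)  # [b0, b1, b2]
```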

Creating a Linear Regression Model in Python

Now that we have a basic understanding of linear regression, let’s dive into the code to create a linear regression model using the sklearn library in Python.

The first step is to import the necessary libraries and load the data. We will use the pandas library to load the data and the scikit-learn library to create the linear regression model.


import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

Next, we will load the data into a pandas DataFrame. In this example, we will use a simple dataset that contains the height and weight of a group of individuals. The data consists of two columns, the height in inches and the weight in pounds. The goal is to fit a linear regression model to this data to find the relationship between the height and weight of individuals. The data can be represented in a 2-dimensional array, where each row represents a sample (an individual), and each column represents a feature (height and weight). The X data is the height of individuals and the y data is their corresponding weight.

Table: Heights and Weights of Individuals for a Linear Regression Model Exercise (columns: height in inches, weight in pounds)

# Load the data
df = pd.read_excel('data.xlsx')

Next, we will split the data into two arrays: X and y. X contains the independent variable (height) and y contains the dependent variable (weight).

# Split the data into X (independent variable) and y (dependent variable)
X = df['height'].values.reshape(-1, 1)
y = df['weight'].values

It’s always a good idea to check the shape of the data to ensure that it has been loaded correctly. We can use the shape attribute to check the shape of the arrays X and y.

# Check the shape of the data
print(X.shape, y.shape)

The output should show that X has n rows and 1 column and y has n rows, where n is the number of samples in the dataset.
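
The file data.xlsx is not included with the article, so the loading and splitting steps above can be tried with a small stand-in DataFrame. The values below are invented for illustration only; the column names `height` and `weight` follow the code above.

```python
import pandas as pd

# Hypothetical measurements, invented for illustration only
df = pd.DataFrame({
    "height": [60, 62, 64, 66, 68, 70, 72, 74],          # inches
    "weight": [115, 120, 128, 136, 145, 155, 165, 176],  # pounds
})

# scikit-learn expects X to be 2-D: (n_samples, n_features)
X = df["height"].values.reshape(-1, 1)
# y can stay 1-D: (n_samples,)
y = df["weight"].values

print(X.shape)
print(y.shape)
```

The `reshape(-1, 1)` call is what turns the single height column into the two-dimensional array that scikit-learn estimators expect.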

Perform simple cross validation

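A simple way to validate the model is to hold out part of the data: split it into training and testing sets using the train_test_split function from the model_selection module of scikit-learn. Strictly speaking, this is a single hold-out split rather than k-fold cross-validation, but it serves the same purpose of checking the model on data it was not trained on.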

In this example, the data is first split into the X data, which is the height of individuals, and the y data, which is their corresponding weight. Then, the train_test_split function is used to split the data into training and testing sets. The test_size argument specifies the proportion of the data to use for testing, and the random_state argument sets the seed for the random number generator used to split the data.

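from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)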
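
For comparison, here is a sketch that places the hold-out split next to a 5-fold cross-validation using scikit-learn's KFold and cross_val_score. The data is synthetic (generated from an assumed linear relation plus noise), and the choice of five folds and the R^2 scoring are illustrative assumptions.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score, train_test_split

# Synthetic height/weight data, generated for illustration only
rng = np.random.default_rng(0)
X = rng.uniform(58, 76, size=(50, 1))
y = -200 + 5 * X[:, 0] + rng.normal(0, 5, size=50)

# Hold-out split as in the text: 80% training, 20% testing
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
print(X_train.shape, X_test.shape)

# 5-fold cross-validation: every sample lands in a test fold exactly once
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearRegression(), X, y, cv=cv, scoring="r2")
print(scores)         # one R^2 value per fold
print(scores.mean())  # average across folds
```

A single split gives one estimate of performance; averaging over folds usually gives a more stable estimate on small datasets.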

Train the linear regression model

Now that we have split the data into training and testing sets, we can create a linear regression model using the LinearRegression class from scikit-learn's linear_model module. The same module also provides the LogisticRegression class, which is used for classification.

# Create a linear regression model
reg = LinearRegression()

Next, we will fit the linear regression model to the data using the fit method.

# Fit the model to the training data
reg.fit(X_train, y_train)

After fitting the model, we can access the intercept and coefficients using the intercept_ and coef_ attributes, respectively.

# Print the intercept and coefficients
print(reg.intercept_, reg.coef_)

The intercept and coefficients represent the parameters b0 and b1 in the equation y = b0 + b1x, respectively.

Finally, we can use the predict method to make predictions for new data.

# Make predictions for new data
new_data = np.array([[65]]) # Height of 65 inches
prediction = reg.predict(new_data)
print(prediction)

This will output the predicted weight for a person with a height of 65 inches.
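
The create, fit, inspect, and predict steps above can be run end to end on a tiny dataset. The sketch below uses hypothetical, noise-free data generated from weight = -200 + 5 * height, so the fitted intercept and slope should match those two numbers; real measurements would not fit this cleanly.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical, noise-free data: weight = -200 + 5 * height
heights = np.array([60, 62, 64, 66, 68, 70]).reshape(-1, 1)
weights = -200 + 5 * heights.ravel()

# Create and fit the model
reg = LinearRegression()
reg.fit(heights, weights)

# intercept_ is b0; coef_ is an array with one slope per feature (here only b1)
print(reg.intercept_)
print(reg.coef_)

# Predict the weight for a height of 65 inches
new_data = np.array([[65]])
prediction = reg.predict(new_data)
print(prediction)
```

Note that predict expects a two-dimensional array even for a single sample, which is why the new height is written as `[[65]]`.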

Cost functions for linear regression models

There are several cost functions that can be used to evaluate the linear regression model. Here are a few common ones:

  1. Mean Squared Error (MSE): MSE is the average of the squared differences between the predicted values and the actual values. The lower the MSE, the better the fit of the model. MSE is expressed as:
MSE = 1/n * Σ(y_i - y_i_pred)^2

where n is the number of samples, y_i is the actual value, and y_i_pred is the predicted value.

  2. Root Mean Squared Error (RMSE): RMSE is the square root of MSE, so it is in the same units as the dependent variable. It is expressed as:
RMSE = √(1/n * Σ(y_i - y_i_pred)^2)

  3. Mean Absolute Error (MAE): MAE is the average of the absolute differences between the predicted values and the actual values. The lower the MAE, the better the fit of the model. MAE is expressed as:
MAE = 1/n * Σ|y_i - y_i_pred|

  4. R-Squared (R^2), a.k.a. the coefficient of determination: R^2 is a measure of the goodness of fit of the linear regression model. It is the proportion of the variance in the dependent variable that is predictable from the independent variables. A value of 1 indicates a perfect fit, and a value of 0 means the model does no better than always predicting the mean. On the training data R^2 lies between 0 and 1; on held-out data it can be negative if the model predicts worse than the mean.

In scikit-learn, these cost functions can be easily computed using the mean_squared_error, mean_absolute_error, and r2_score functions from the metrics module. For example:

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predict on the held-out test set with the fitted model (named reg above)
y_pred = reg.predict(X_test)

# Mean Squared Error
mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error:", mse)

# Root Mean Squared Error
rmse = np.sqrt(mse)  # square root of MSE; avoids the squared=False argument, which newer scikit-learn versions deprecate
print("Root Mean Squared Error:", rmse)

# Mean Absolute Error
mae = mean_absolute_error(y_test, y_pred)
print("Mean Absolute Error:", mae)

# R-Squared
r2 = r2_score(y_test, y_pred)
print("R-Squared:", r2)

These cost functions provide different perspectives on the performance of the linear regression model and can be used to choose the best model for a given problem.
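
To connect the formulas to the library functions, the following sketch computes each metric by hand with NumPy and compares it with the scikit-learn result. The four actual and predicted values are hypothetical and chosen so the arithmetic is easy to follow.

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Hypothetical actual and predicted values
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_hat = np.array([2.5, 5.0, 8.0, 8.5])

errors = y_true - y_hat                 # [0.5, 0.0, -1.0, 0.5]

mse = np.mean(errors ** 2)              # (0.25 + 0 + 1 + 0.25) / 4 = 0.375
rmse = np.sqrt(mse)                     # square root of MSE
mae = np.mean(np.abs(errors))           # (0.5 + 0 + 1 + 0.5) / 4 = 0.5

ss_res = np.sum(errors ** 2)                     # residual sum of squares = 1.5
ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares = 20
r2 = 1 - ss_res / ss_tot                         # 1 - 1.5 / 20 = 0.925

# The hand-computed values agree with scikit-learn
print(np.isclose(mse, mean_squared_error(y_true, y_hat)))
print(np.isclose(mae, mean_absolute_error(y_true, y_hat)))
print(np.isclose(r2, r2_score(y_true, y_hat)))
```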


In this article, we learned how to create a linear regression model using the scikit-learn library in Python. We first split the data into X and y, created a linear regression model, fit the model to the data, and finally made predictions for new data.

Linear regression is a simple and powerful method for analyzing the relationship between variables. By using the scikit-learn library in Python, we can easily create and fit linear regression models to our data and make predictions.

Frequently Asked Questions about Linear Regression with Sklearn in Python

  1. Which Python library is best for linear regression?

    scikit-learn (sklearn) is one of the most widely used Python libraries for machine learning, and it is designed for training models and making predictions. It offers several options for numerical calculations and statistical modelling. Its LinearRegression class, found in the sklearn.linear_model module, fits an ordinary least squares linear regression.

  2. What is linear regression used for?

    Linear regression analysis is used to predict the value of a target variable based on the value of one or more independent variables. The variable you want to predict / explain is called the dependent or target variable. The variable you are using to predict the dependent variable's value is called the independent or feature variable.

  3. What are the 2 most common models of regression analysis?

    The most common models are simple linear and multiple linear. Nonlinear regression analysis is commonly used for more complicated data sets in which the dependent and independent variables show a nonlinear relationship. Regression analysis offers numerous applications in various disciplines.

  4. What are the advantages of linear regression?

    The biggest advantage of linear regression models is linearity: It makes the estimation procedure simple and, most importantly, these linear equations have an easy to understand interpretation on a modular level (i.e. the weights).

  5. What is the difference between correlation and linear regression?

    Correlation quantifies the strength of the linear relationship between a pair of variables, whereas regression expresses the relationship in the form of an equation.

  6. What is LinearRegression in Sklearn?

    LinearRegression fits a linear model with coefficients w = (w1, …, wp) to minimize the residual sum of squares between the observed targets in the dataset, and the targets predicted by the linear approximation.

  7. What is the full form of sklearn?

    scikit-learn (also known as sklearn) is a free software machine learning library for the Python programming language.

  8. What is the syntax for linear regression model in Python?

    from sklearn.linear_model import LinearRegression
    lr = LinearRegression()
    lr.fit(X, y)
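
The difference between correlation and regression mentioned in the answers above can be made concrete with a short sketch. It uses hypothetical data and relies on the standard identity that, in simple linear regression, the slope equals the correlation coefficient multiplied by the ratio of the standard deviations of y and x.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical paired observations
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 3.9, 6.2, 7.8, 10.1])

# Correlation: a single number describing the strength of the linear relationship
r = np.corrcoef(x, y)[0, 1]

# Regression: an equation y = b0 + b1 * x
lr = LinearRegression().fit(x.reshape(-1, 1), y)
slope = lr.coef_[0]
intercept = lr.intercept_

print(r)
print(intercept, slope)

# The two are linked: slope = r * std(y) / std(x)
print(np.isclose(slope, r * y.std() / x.std()))
```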


Basic Python Data Structures – Lists, tuples, sets, dictionaries


List methods

  • list.append(x)
    Add an item to the end of the list. Equivalent to a[len(a):] = [x].
  • list.extend(iterable)
    Extend the list by appending all the items from the iterable. Equivalent to a[len(a):] = iterable.
  • list.insert(i, x)
    Insert an item at a given position. The first argument is the index of the element before which to insert, so a.insert(0, x) inserts at the front of the list, and a.insert(len(a), x) is equivalent to a.append(x).
  • list.remove(x)
    Remove the first item from the list whose value is equal to x. It raises a ValueError if there is no such item.
  • list.pop([i])
    Remove the item at the given position in the list, and return it. If no index is specified, a.pop() removes and returns the last item in the list. (The square brackets around the i in the method signature denote that the parameter is optional, not that you should type square brackets at that position. You will see this notation frequently in the Python Library Reference.)
  • list.clear()
    Remove all items from the list. Equivalent to del a[:].
  • list.index(x[, start[, end]])
    Return zero-based index in the list of the first item whose value is equal to x. Raises a ValueError if there is no such item. The optional arguments start and end are interpreted as in the slice notation and are used to limit the search to a particular subsequence of the list. The returned index is computed relative to the beginning of the full sequence rather than the start argument.
  • list.count(x)
    Return the number of times x appears in the list.
  • list.sort(*, key=None, reverse=False)
    Sort the items of the list in place (the arguments can be used for sort customization, see sorted() for their explanation).
  • list.reverse()
    Reverse the elements of the list in place.
  • list.copy()
    Return a shallow copy of the list. Equivalent to a[:].
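
A short sketch exercising the list methods above on a small example list. The values are arbitrary, and the comments show the list contents after each step.

```python
a = [3, 1, 2]

a.append(4)        # a is now [3, 1, 2, 4]
a.extend([1, 5])   # a is now [3, 1, 2, 4, 1, 5]
a.insert(0, 9)     # a is now [9, 3, 1, 2, 4, 1, 5]
a.remove(1)        # removes the first 1: [9, 3, 2, 4, 1, 5]

last = a.pop()     # removes and returns the last item (5)
first = a.pop(0)   # removes and returns the item at index 0 (9)
# a is now [3, 2, 4, 1]

idx = a.index(4)   # position of the first 4
n = a.count(1)     # how many times 1 appears

snapshot = a.copy()  # shallow copy, unaffected by later changes to a
a.sort()             # in place: [1, 2, 3, 4]
a.reverse()          # in place: [4, 3, 2, 1]

scratch = [1, 2]
scratch.clear()      # scratch is now []

print(a, snapshot, last, first, idx, n, scratch)
```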

Tuple methods

  • tuple.count(x)
    Returns the number of times x occurs in the tuple.
  • tuple.index(x)
    Returns the index of the first occurrence of x in the tuple. Raises a ValueError if there is no such item.
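
Because tuples are immutable, these two methods only read from the tuple. A minimal sketch with arbitrary example values:

```python
t = ("a", "b", "a", "c")

occurrences = t.count("a")       # number of times "a" appears
first_pos = t.index("a")         # index of the first "a"
next_pos = t.index("a", 1)       # search again, starting from index 1

# index raises ValueError when the value is absent
try:
    t.index("z")
    found = True
except ValueError:
    found = False

print(occurrences, first_pos, next_pos, found)
```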

Set methods

  • set.add()
    Adds an element to the set.
  • set.clear()
    Removes all the elements from the set.
  • set.copy()
    Returns a copy of the set.
  • set.difference()
    Returns a set containing the difference between two or more sets.
  • set.difference_update()
    Removes the items in this set that are also included in another, specified set.
  • set.discard()
    Removes the specified item if it is present. Does nothing if the item is not in the set.
  • set.intersection()
    Returns a set, that is the intersection of two or more sets.
  • set.intersection_update()
    Removes the items in this set that are not present in other, specified set(s).
  • set.isdisjoint()
    Returns True if the two sets have no elements in common, and False otherwise.
  • set.issubset()
    Returns whether another set contains this set or not.
  • set.issuperset()
    Returns whether this set contains another set or not.
  • set.pop()
    Removes and returns an arbitrary element from the set. Raises a KeyError if the set is empty.
  • set.remove()
    Removes the specified element. Raises a KeyError if the element is not in the set.
  • set.symmetric_difference()
    Returns a set with the symmetric differences of two sets.
  • set.symmetric_difference_update()
    Updates this set so that it keeps only the elements found in either this set or another, but not in both.
  • set.union()
    Return a set containing the union of sets.
  • set.update()
    Update the set with another set, or any other iterable.
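
The set methods above are easiest to follow on two small overlapping sets. The sketch below uses arbitrary integers; the comments show the expected contents.

```python
a = {1, 2, 3, 4}
b = {3, 4, 5}

union = a.union(b)                 # {1, 2, 3, 4, 5}
inter = a.intersection(b)          # {3, 4}
diff = a.difference(b)             # {1, 2}
sym = a.symmetric_difference(b)    # {1, 2, 5}

disjoint = a.isdisjoint({7, 8})    # True: no elements in common
subset = {1, 2}.issubset(a)        # True: a contains both 1 and 2
superset = a.issuperset({1, 2})    # True

a.discard(99)                      # no error, 99 is simply not there
try:
    a.remove(99)                   # remove raises KeyError for a missing element
    remove_failed = False
except KeyError:
    remove_failed = True

a.add(6)
a.update([7, 8])                   # a is now {1, 2, 3, 4, 6, 7, 8}
a.difference_update(b)             # drops 3 and 4: {1, 2, 6, 7, 8}

c = {1, 2, 3}
c.intersection_update({2, 3, 4})          # c is now {2, 3}
c.symmetric_difference_update({3, 4})     # c is now {2, 4}

print(union, inter, diff, sym, a, c)
```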

Dictionary methods

  • dict.clear()
    Removes all the elements from the dictionary.
  • dict.copy()
    Returns a copy of the dictionary.
  • dict.fromkeys()
    Returns a dictionary with the specified keys and value.
  • dict.get()
    Returns the value of the specified key, or a default value (None unless another is given) if the key does not exist.
  • dict.items()
    Returns a view object containing a (key, value) tuple for each key-value pair.
  • dict.keys()
    Returns a view object containing the dictionary’s keys.
  • dict.pop()
    Removes the element with the specified key and returns its value.
  • dict.popitem()
    Removes and returns the last inserted key-value pair as a tuple.
  • dict.setdefault()
    Returns the value of the specified key. If the key does not exist: insert the key, with the specified value.
  • dict.update()
    Updates the dictionary with the specified key-value pairs.
  • dict.values()
    Returns a view object containing all the values in the dictionary.
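
A brief sketch of the dictionary methods above, using arbitrary keys and values. It also shows that keys(), values(), and items() return views that reflect later changes to the dictionary rather than fixed lists.

```python
d = dict.fromkeys(["a", "b"], 0)   # {'a': 0, 'b': 0}
d.update({"b": 2, "c": 3})         # {'a': 0, 'b': 2, 'c': 3}

missing = d.get("z")               # None, no KeyError
fallback = d.get("z", -1)          # explicit default

added = d.setdefault("d", 4)       # key absent: inserts 'd': 4 and returns 4
existing = d.setdefault("a", 100)  # key present: returns the current value, 0

keys = d.keys()                    # a view, not a list
d["e"] = 5
view_sees_new_key = "e" in keys    # the view reflects the insertion

popped = d.pop("a")                # removes 'a' and returns its value
last = d.popitem()                 # removes and returns the last inserted pair

items = list(d.items())            # [('b', 2), ('c', 3), ('d', 4)]
values = list(d.values())          # [2, 3, 4]

backup = d.copy()                  # shallow copy
d.clear()                          # d is now empty, backup is untouched

print(items, values, popped, last, backup, d)
```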