Learn how to create and visualize the k-means algorithm - a very basic clustering algorithm that is often taught in introductory data science classes

In this article we're going to look into k-means and how it can be visualized using Python. Below is the code necessary to create an animation. You may simply copy and paste it to create your own animation of k-means. We also explain each of its functions in the Explanation section below the code.

Note: We sacrifice some modularity/extensibility and break certain code principles in this example to maximize its readability and minimize the amount of code required to reach the goal described in the title.

## The K-Means Animation

## Code to Visualize K-Means

from __future__ import annotations

import math
import sys
from io import BytesIO
from typing import List, Optional

import imageio
import matplotlib.pyplot as plt
import numpy as np
N_POINTS = 500  # The number of points to cluster
N_CLUSTERS = 5  # The number of clusters (k)
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]  # One color per cluster
# Machine epsilon of Python floats: the loop in k_means treats a shift of the
# centres below this value as "no movement", i.e. convergence.
EPSILON = sys.float_info.epsilon

class Point:
    """A colored point in the 2-dimensional plane."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates together with the plotting attributes.

        Args:
            x: Horizontal coordinate.
            y: Vertical coordinate.
            color: Matplotlib color used when the point is drawn.
            magnitude: Marker size handed to pyplot's ``s`` argument.
        """
        self.x, self.y = x, y
        self.color = color
        self.magnitude = magnitude

    def distance_to_point(self, point: Point):
        """Return the Euclidean distance between this point and ``point``."""
        dx, dy = self.x - point.x, self.y - point.y
        return math.sqrt(dx ** 2 + dy ** 2)

class PointList:
    """A thin wrapper around a list of Point objects with plotting helpers."""

    def __init__(self, points: Optional[List[Point]] = None, marker: str = "x"):
        """Create a PointList.

        Args:
            points: Initial points. A non-empty list is stored as-is (not
                copied); ``None`` or an empty list starts a fresh empty list.
            marker: The marker (symbol) used by :meth:`plot`.
        """
        self.points = points if points else []
        self.marker = marker

    def __len__(self) -> int:
        """Number of points, so the built-in ``len()`` works on a PointList."""
        return len(self.points)

    def __iter__(self):
        """Iterate over the contained points."""
        return iter(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes (marker sizes) of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Draw the points with ``plt.scatter`` and return its result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the end of the PointList."""
        self.points.append(point)

    def len(self):
        """Return the number of points (kept for compatibility; prefer ``len(pl)``)."""
        return len(self)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The mean x value; raises ZeroDivisionError for an empty list."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The mean y value; raises ZeroDivisionError for an empty list."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Return the Euclidean norm of the pairwise offsets between two lists.

        Points are paired by position; ``zip`` stops at the shorter list.
        """
        return math.sqrt(
            sum(
                (own.x - other.x) ** 2 + (own.y - other.y) ** 2
                for own, other in zip(self.points, other_points_list.points)
            )
        )

def random_point(**kwargs):
    """Create a Point with x and y drawn uniformly from [0, 1).

    Keyword arguments (e.g. ``color``, ``magnitude``) are passed straight
    through to the Point constructor.
    """
    # rand() without a size returns a plain float; rand(1) would return a
    # 1-element ndarray that then leaks into math.sqrt, the averages and plots.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)

def random_points(n: int):
    """Return a PointList holding ``n`` freshly generated random points."""
    return PointList([random_point() for _ in range(n)])

def create_random_cluster_centres(k: int):
    """Return a PointList of ``k`` randomly placed cluster centres.

    Centre i gets color COLORS[i] (the palette repeats if k exceeds it), a
    magnitude of 150 so it stands out, and the "o" marker.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Cycle through the palette so exactly k centres are created even when
        # k > len(COLORS); zip() would silently stop at the palette length.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres

def create_k_point_lists(k: int):
    """Build and return a list of ``k`` independent, empty PointLists."""
    return [PointList() for _ in range(k)]

def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its nearest centre.

    Returns one PointList per centre, in the order of ``centres``, holding the
    points closest to that centre (ties go to the earliest centre). As a
    visualization side effect, each point is recolored with its centre's color.
    """
    centre_list = centres.points
    clusters = create_k_point_lists(len(centre_list))
    for point in points.points:
        distances = [point.distance_to_point(c) for c in centre_list]
        nearest = distances.index(min(distances))
        clusters[nearest].append(point)
        point.color = centre_list[nearest].color
    return clusters

def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList with one new centre per cluster.

    Each centre is placed at the mean x/y position of its cluster's points and
    takes the color of the cluster's first point. cluster_points colors every
    member with its centre's color, so this keeps the cluster's color.

    Raises:
        ValueError: if a cluster contains no points (its mean is undefined).
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if cluster.len() == 0:
            raise ValueError(f"cluster {index} is empty; its centre is undefined")
        new_centres.append(
            Point(
                x=cluster.x_avg,  # mean x of all points in the cluster
                y=cluster.y_avg,  # mean y of all points in the cluster
                # A single color, not the whole cluster.colors list: the list
                # would be copied onto every point in the next clustering pass.
                color=cluster.colors[0],
                magnitude=150,  # centres are drawn larger to stand out
            )
        )
    return new_centres

def plot_styling():
    """Open a new figure with a dark background and white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)  # figure background
    axis = plt.gca()  # axes of the freshly created figure
    axis.set_facecolor(background)
    for side in ("bottom", "top", "right", "left"):
        axis.spines[side].set_color("white")
    for direction in ("x", "y"):
        axis.tick_params(axis=direction, colors="white")

def k_means(points: PointList, centres: PointList, frames: Optional[list] = None):
    """Run k-means on ``points`` and save every iteration to ``k-means.gif``.

    The loop repeats three steps until the centres stop moving:
    1. assign each point to its nearest centre,
    2. move every centre to the mean of its points,
    3. measure how far the centres moved.
    After each iteration, a frame with the current assignment and the centres
    used for it is rendered in memory. When the loop ends, all frames are
    written to an animated GIF.

    Args:
        points: The points to cluster (their colors are updated in place).
        centres: The initial cluster centres.
        frames: List of in-memory image buffers to append to. When omitted, a
            module-level ``frames`` list (as created by the ``__main__`` block)
            is used if it exists; otherwise a new list is started.

    Returns:
        The final (converged) centres as a PointList.
    """
    if frames is None:
        # The __main__ block stores the "Initialization" frame in a module-level
        # `frames` list and calls k_means(points, centres). Reuse that list when
        # present instead of failing with a NameError when imported elsewhere.
        frames = globals().get("frames", [])

    difference = 1  # any value >= EPSILON, so the loop runs at least once
    n = 1  # iteration counter, used for the frame titles
    # difference is a Euclidean norm (never negative), so no abs() is needed
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        # An empty cluster has no mean: seed it with a copy of its current
        # centre so that centre simply stays put for this iteration.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(Point(centre.x, centre.y, color=centre.color))
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)  # total shift of the centres

        # Animation only
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # free the figure memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            frame.seek(0)  # writing the image left the buffer positioned at its end
            writer.append_data(imageio.imread(frame))
    return centres

if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    frames = []  # module-level list of in-memory frames; k_means appends to it

    # Snapshot of the random starting state, before any clustering
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Run the algorithm (k_means also writes k-means.gif)
    k_means(points, centres)

## Explanation

The code above should run without any issues. If you do encounter any, feel free to comment at the bottom of the page. If you modified the code and run into issues, the following explanations might help.

### Libraries

First of all we need some libraries. We use:

• the `__future__` module (`from __future__ import annotations`) so that type hints can refer to classes before they are defined
• the Python built-in module math for a few mathematical functions
• the sys module to determine the epsilon of the float data type
• io to store byte objects in memory
• the typing module for type hints
• imageio to combine the rendered frames into an animated GIF and save it
• pyplot from matplotlib — as the name suggests — for plotting the points and centres of each iteration. By convention, matplotlib.pyplot is imported as plt
• numpy to generate the random coordinates of the points (numpy's fast, C-based routines are the usual choice for numerical work in Python). By convention, we import numpy as np to reduce the length of the code lines
from __future__ import annotations
import mathimport sysfrom io import BytesIOfrom typing import List
import imageioimport matplotlib.pyplot as pltimport numpy as np
N_POINTS = 500  # The number of points to cluster
N_CLUSTERS = 5  # The number of clusters (k)
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]  # One color per cluster
# Machine epsilon of Python floats: the loop in k_means treats a shift of the
# centres below this value as "no movement", i.e. convergence.
EPSILON = sys.float_info.epsilon

We also define a few constants:

• N_POINTS for the number of dots we want to display in our k-means algorithm
• N_CLUSTERS for the number of clusters (equals k)
• COLORS for the colors of the clusters
• EPSILON, the machine epsilon of Python floats, used as the threshold below which the movement of the centres counts as zero

### Point in a 2-dimensional Plane

We start the code by defining a class to represent a point in a 2-dimensional plane. It has a single, simple method that determines its Euclidean distance to another point.

Short recap of the Euclidean distance: the Euclidean distance of two points p and q is the square root of the sum of the squared differences between their coordinates in each dimension; or as a formula:

$$d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$

In other words, as we only need 2 dimensions for our points, it is merely the square root of the sum of the squared x and y differences, $d(p, q) = \sqrt{(p_x - q_x)^2 + (p_y - q_y)^2}$, as shown in the code below.

class Point:
    """A colored point in the 2-dimensional plane."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates together with the plotting attributes.

        Args:
            x: Horizontal coordinate.
            y: Vertical coordinate.
            color: Matplotlib color used when the point is drawn.
            magnitude: Marker size handed to pyplot's ``s`` argument.
        """
        self.x, self.y = x, y
        self.color = color
        self.magnitude = magnitude

    def distance_to_point(self, point: Point):
        """Return the Euclidean distance between this point and ``point``."""
        dx, dy = self.x - point.x, self.y - point.y
        return math.sqrt(dx ** 2 + dy ** 2)

Additionally, we add the attributes color and magnitude to make it easier to distinguish the points visually.

### Point List

To ease handling these points a bit, we create a data structure called PointList, which is a thin wrapper around a list with a few methods specifically tailored to handling and plotting points.

class PointList:
    """A thin wrapper around a list of Point objects with plotting helpers."""

    def __init__(self, points: Optional[List[Point]] = None, marker: str = "x"):
        """Create a PointList.

        Args:
            points: Initial points. A non-empty list is stored as-is (not
                copied); ``None`` or an empty list starts a fresh empty list.
            marker: The marker (symbol) used by :meth:`plot`.
        """
        self.points = points if points else []
        self.marker = marker

    def __len__(self) -> int:
        """Number of points, so the built-in ``len()`` works on a PointList."""
        return len(self.points)

    def __iter__(self):
        """Iterate over the contained points."""
        return iter(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes (marker sizes) of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Draw the points with ``plt.scatter`` and return its result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the end of the PointList."""
        self.points.append(point)

    def len(self):
        """Return the number of points (kept for compatibility; prefer ``len(pl)``)."""
        return len(self)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The mean x value; raises ZeroDivisionError for an empty list."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The mean y value; raises ZeroDivisionError for an empty list."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Return the Euclidean norm of the pairwise offsets between two lists.

        Points are paired by position; ``zip`` stops at the shorter list.
        """
        return math.sqrt(
            sum(
                (own.x - other.x) ** 2 + (own.y - other.y) ** 2
                for own, other in zip(self.points, other_points_list.points)
            )
        )

We make use of Python's built-in `property` decorator for various properties of this list type. It is mere syntactic sugar that lets methods be accessed like attributes, i.e. without parentheses.

### Create Random PointList

To create some data we can cluster, we write two simple functions: one creates a random Point in a 1x1 plane, and the other creates a PointList that contains n of these random points.

def random_point(**kwargs):
    """Create a Point with x and y drawn uniformly from [0, 1).

    Keyword arguments (e.g. ``color``, ``magnitude``) are passed straight
    through to the Point constructor.
    """
    # rand() without a size returns a plain float; rand(1) would return a
    # 1-element ndarray that then leaks into math.sqrt, the averages and plots.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)

def random_points(n: int):
    """Return a PointList holding ``n`` freshly generated random points."""
    return PointList([random_point() for _ in range(n)])

### Create Random K-Clusters

The plain k-means algorithm is initialized by using randomly placed centres so we create a helper function that does exactly that.

def create_random_cluster_centres(k: int):
    """Return a PointList of ``k`` randomly placed cluster centres.

    Centre i gets color COLORS[i] (the palette repeats if k exceeds it), a
    magnitude of 150 so it stands out, and the "o" marker.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Cycle through the palette so exactly k centres are created even when
        # k > len(COLORS); zip() would silently stop at the palette length.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres

Apart from being drawn with a different marker and a larger size, there is basically no difference between a point and a cluster centre, which is why we reuse random_point to create the centres.

### Create K Point Lists

Another helper function we're going to use is create_k_point_lists. It returns a list of k empty PointLists, to ease handling them.

def create_k_point_lists(k: int):
    """Build and return a list of ``k`` independent, empty PointLists."""
    return [PointList() for _ in range(k)]

### Cluster Points

The first real sub-step of k-means is the cluster_points function. It receives a PointList of the points and a PointList of the centres, and assigns each point to the cluster of the centre it is closest to.

def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its nearest centre.

    Returns one PointList per centre, in the order of ``centres``, holding the
    points closest to that centre (ties go to the earliest centre). As a
    visualization side effect, each point is recolored with its centre's color.
    """
    centre_list = centres.points
    clusters = create_k_point_lists(len(centre_list))
    for point in points.points:
        distances = [point.distance_to_point(c) for c in centre_list]
        nearest = distances.index(min(distances))
        clusters[nearest].append(point)
        point.color = centre_list[nearest].color
    return clusters

### Calculate New Centres

Another sub-step of k-means is calculate_new_centres, which computes each cluster's new centre as the average x and y coordinates of all the points in that cluster.

def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList with one new centre per cluster.

    Each centre is placed at the mean x/y position of its cluster's points and
    takes the color of the cluster's first point. cluster_points colors every
    member with its centre's color, so this keeps the cluster's color.

    Raises:
        ValueError: if a cluster contains no points (its mean is undefined).
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if cluster.len() == 0:
            raise ValueError(f"cluster {index} is empty; its centre is undefined")
        new_centres.append(
            Point(
                x=cluster.x_avg,  # mean x of all points in the cluster
                y=cluster.y_avg,  # mean y of all points in the cluster
                # A single color, not the whole cluster.colors list: the list
                # would be copied onto every point in the next clustering pass.
                color=cluster.colors[0],
                magnitude=150,  # centres are drawn larger to stand out
            )
        )
    return new_centres

### Plot Styling

def plot_styling():
    """Open a new figure with a dark background and white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)  # figure background
    axis = plt.gca()  # axes of the freshly created figure
    axis.set_facecolor(background)
    for side in ("bottom", "top", "right", "left"):
        axis.spines[side].set_color("white")
    for direction in ("x", "y"):
        axis.tick_params(axis=direction, colors="white")

### K-Means Algorithm

Finally, we put all the code together and run the k_means algorithm. If the file is run directly, the `if __name__ == '__main__':` condition is met and the code listed below it is executed.

def k_means(points: PointList, centres: PointList):
    """Run k-means on ``points``, starting from ``centres``, until the centres stop moving.

    Each iteration assigns every point to its nearest centre, moves each
    centre to the mean of its points and measures how far the centres moved.

    Returns:
        The final (converged) centres as a PointList.
    """
    difference = 1  # any value >= EPSILON, so the loop runs at least once
    # difference is a Euclidean norm (never negative), so no abs() is needed
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        # An empty cluster has no mean: seed it with a copy of its current
        # centre so that centre simply stays put for this iteration.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(Point(centre.x, centre.y, color=centre.color))
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)  # total shift of the centres
        centres = new_centres
    return centres

if __name__ == '__main__':
    # Random points to cluster and k randomly placed starting centres
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # Run the algorithm
    k_means(points=points, centres=centres)

We create a PointList of N_POINTS random points to cluster and initialize k-means with N_CLUSTERS random points as the centres. When k-means runs, 3 steps are repeated until the centres of the clusters no longer move (compared to the previous iteration):

1. Cluster the points by assigning them to the closest centre
2. Calculate the new centres as the average position of all points in each cluster
3. Calculate how far the centres moved compared to the previous iteration to determine whether the termination condition is met

### Create the animation

To create an animation of these, we take a kind of snapshot of each iteration at specific points in the algorithm. This modifies the last part of the code slightly:

def k_means(points: PointList, centres: PointList, frames: Optional[list] = None):
    """Run k-means on ``points`` and save every iteration to ``k-means.gif``.

    The loop repeats three steps until the centres stop moving:
    1. assign each point to its nearest centre,
    2. move every centre to the mean of its points,
    3. measure how far the centres moved.
    After each iteration, a frame with the current assignment and the centres
    used for it is rendered in memory. When the loop ends, all frames are
    written to an animated GIF.

    Args:
        points: The points to cluster (their colors are updated in place).
        centres: The initial cluster centres.
        frames: List of in-memory image buffers to append to. When omitted, a
            module-level ``frames`` list (as created by the ``__main__`` block)
            is used if it exists; otherwise a new list is started.

    Returns:
        The final (converged) centres as a PointList.
    """
    if frames is None:
        # The __main__ block stores the "Initialization" frame in a module-level
        # `frames` list and calls k_means(points, centres). Reuse that list when
        # present instead of failing with a NameError when imported elsewhere.
        frames = globals().get("frames", [])

    difference = 1  # any value >= EPSILON, so the loop runs at least once
    n = 1  # iteration counter, used for the frame titles
    # difference is a Euclidean norm (never negative), so no abs() is needed
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        # An empty cluster has no mean: seed it with a copy of its current
        # centre so that centre simply stays put for this iteration.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(Point(centre.x, centre.y, color=centre.color))
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)  # total shift of the centres

        # Animation only
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # free the figure memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            frame.seek(0)  # writing the image left the buffer positioned at its end
            writer.append_data(imageio.imread(frame))
    return centres

if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    frames = []  # module-level list of in-memory frames; k_means appends to it

    # Snapshot of the random starting state, before any clustering
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Run the algorithm (k_means also writes k-means.gif)
    k_means(points, centres)