In this article we're going to look into k-means and how it can be visualized using Python. Below is the code necessary to create an animation. You may simply copy and paste it to create your own animation of k-means. We also explain each of its functions in the explanation part below.

**Note**: We sacrifice some modularity/extensibility and break certain code
principles in this example to maximize its readability and minimize the amount of
code required to reach the goal described in the title.

## The K-Means Animation

## Code to visualize K-Means

from __future__ import annotations

import math
import sys
from io import BytesIO
from typing import List

import imageio
import matplotlib.pyplot as plt
import numpy as np

N_POINTS = 500  # The number of points to cluster
N_CLUSTERS = 5  # The number of clusters (k)
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]  # The colors of each cluster
# Machine epsilon of Python floats; a centre shift below this counts as "no movement".
EPSILON = sys.float_info.epsilon


class Point:
    """A colored point in the 2-dimensional plane."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Point constructor."""
        self.x = x  # The x coordinate
        self.y = y  # The y coordinate
        self.color = color  # The point's color
        self.magnitude = magnitude  # The marker size of the point (``s`` in plt.scatter)

    def distance_to_point(self, point: Point) -> float:
        """Return the Euclidean distance between this point and ``point``."""
        delta_x = self.x - point.x  # The difference in x direction
        delta_y = self.y - point.y  # The difference in y direction
        return math.sqrt(delta_x ** 2 + delta_y ** 2)


class PointList:
    """A list of points with helpers for aggregating and plotting them."""

    def __init__(self, points: List[Point] | None = None, marker: str = "x"):
        """Wrap ``points`` (a new empty list when none are given) and store the plot marker."""
        self.points = points if points else []
        self.marker = marker  # The marker (symbol in the pyplot)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes (marker sizes) of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Draw the points as a scatter plot and return pyplot's result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the list."""
        self.points.append(point)

    def len(self):
        """Return the number of points in the list."""
        return len(self.points)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The average x value; raises ZeroDivisionError for an empty list."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The average y value; raises ZeroDivisionError for an empty list."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Return the root of the summed squared distances of index-paired points.

        Points are paired by position; surplus points of the longer list are ignored.
        """
        differences = [
            (own_point.x - list_point.x) ** 2 + (own_point.y - list_point.y) ** 2
            for own_point, list_point in zip(self.points, other_points_list.points)
        ]
        return math.sqrt(sum(differences))


def random_point(**kwargs) -> Point:
    """Create a point with uniformly random coordinates in [0, 1) x [0, 1).

    All keyword arguments (e.g. ``color``, ``magnitude``) are passed on to ``Point``.
    """
    # Called without arguments, np.random.rand() returns a plain float; rand(1)
    # would return a one-element array, which math.sqrt() only accepts through
    # NumPy's deprecated array-to-scalar conversion.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)


def random_points(n: int) -> PointList:
    """Create a PointList of ``n`` random (grey) points."""
    points = PointList()
    for _ in range(n):
        points.append(random_point())
    return points


def create_random_cluster_centres(k: int) -> PointList:
    """Return a PointList of ``k`` randomly placed, colored cluster centres.

    Colors come from ``COLORS`` and are reused cyclically when ``k`` exceeds
    the number of available colors, so exactly ``k`` centres are returned.
    """
    centres = PointList(marker="o")
    for index in range(k):
        centres.append(random_point(color=COLORS[index % len(COLORS)], magnitude=150))
    return centres


def create_k_point_lists(k: int) -> List[PointList]:
    """Return ``k`` independent, empty PointLists."""
    return [PointList() for _ in range(k)]


def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to the cluster of its nearest centre.

    Ties go to the centre that comes first. Each point is also recolored with
    its centre's color (visualization only). Returns one PointList per centre,
    in centre order.
    """
    centre_points = centres.points
    clusters = create_k_point_lists(len(centre_points))
    for point in points.points:
        distances = [point.distance_to_point(centre) for centre in centre_points]
        centre_index = distances.index(min(distances))
        clusters[centre_index].append(point)
        point.color = centre_points[centre_index].color
    return clusters


def calculate_new_centres(
    clusters: List[PointList], previous_centres: PointList | None = None
) -> PointList:
    """Return one new centre per cluster, placed at the mean of its points.

    Args:
        clusters: The clusters as produced by ``cluster_points``.
        previous_centres: Optional centres of the previous iteration, in the same
            order as ``clusters``. An empty cluster has no mean, so its previous
            centre is kept in that case.

    Raises:
        ValueError: If a cluster is empty and no ``previous_centres`` were given.
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if cluster.len() == 0:
            if previous_centres is None:
                raise ValueError(
                    f"Cluster {index} is empty; pass previous_centres to keep its old centre"
                )
            old_centre = previous_centres.points[index]
            new_centres.append(
                Point(x=old_centre.x, y=old_centre.y, color=old_centre.color, magnitude=150)
            )
            continue
        new_centres.append(
            Point(
                x=cluster.x_avg,  # Mean x of the cluster's points
                y=cluster.y_avg,  # Mean y of the cluster's points
                color=cluster.colors[0],  # Points were recolored to their centre's color
                magnitude=150,  # Centres are drawn larger to identify them visually
            )
        )
    return new_centres


def plot_styling():
    """Open a new figure with a dark background and white axes."""
    plt.figure(facecolor="#111827")  # Background color
    axis = plt.gca()
    axis.set_facecolor("#111827")  # Axis background color
    axis.spines['bottom'].set_color('white')
    axis.spines['top'].set_color('white')
    axis.spines['right'].set_color('white')
    axis.spines['left'].set_color('white')
    axis.tick_params(axis='x', colors='white')
    axis.tick_params(axis='y', colors='white')


def k_means(points: PointList, centres: PointList, frames: list | None = None) -> PointList:
    """Run k-means, snapshot every iteration, write ``k-means.gif`` and return the final centres.

    Args:
        points: The points to cluster; they are recolored in place.
        centres: The initial centres.
        frames: In-memory image frames recorded so far (e.g. the initial view);
            iteration frames are appended to it. A new list is used when omitted.
    """
    if frames is None:
        frames = []
    difference = 1  # Any start value >= EPSILON so the loop body runs at least once
    n = 1  # Iteration counter for the frame titles
    while difference >= EPSILON:  # difference is a non-negative distance
        new_clusters = cluster_points(points, centres)
        new_centres = calculate_new_centres(new_clusters, previous_centres=centres)
        difference = new_centres.difference(centres)  # How far the centres moved

        # Animation only: previous centres, points colored by the new assignment
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()
        plt.savefig(frame_bytes, format="png")
        frame_bytes.seek(0)  # Rewind so the reader starts at the first byte of the image
        frames.append(frame_bytes)
        plt.close("all")  # Close all figures to free the memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            writer.append_data(imageio.imread(frame))
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    frames = []

    # Initial view: random centres, all points still grey
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    frame_bytes = BytesIO()
    plt.savefig(frame_bytes, format="png")
    frame_bytes.seek(0)
    frames.append(frame_bytes)
    plt.close("all")

    # Run the algorithm
    k_means(points, centres, frames)

## Explanation

The code above should run without any issues. If you encounter any, though, feel free to comment at the bottom of the page. If you modify the code and run into issues, the following explanations might help.

### Libraries

First of all we need some libraries. We use:

- the `__future__` package to be able to use type hints before their definition
- the Python built-in module `math` for a few mathematical methods
- the `sys` package to determine the epsilon of the float datatype
- `io` to store byte objects in memory
- the `typing` package for type hints
- `imageio` to assemble the frames into our animation and save it
- `pyplot` from `matplotlib`, as the name lets one assume, for plotting the data (or here, the pictures, as they're nothing else than matrices themselves). By convention, `matplotlib.pyplot` is imported as `plt`
- `numpy` to ease the array handling and to outsource some operations to the much faster C-based libraries it includes. By convention, we import `numpy` as `np` to reduce the length of the code lines

from __future__ import annotations

import math
import sys
from io import BytesIO
from typing import List

import imageio
import matplotlib.pyplot as plt
import numpy as np

# Size of the experiment
N_POINTS = 500  # how many random points get clustered
N_CLUSTERS = 5  # k, the number of clusters / centres

# One display color per cluster (blue, green, yellow, purple, red)
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]

# Machine epsilon of Python floats: centre shifts smaller than this count as zero
EPSILON = sys.float_info.epsilon

We also define a few constants, respectively:

- `N_POINTS` for the number of dots we want to cluster with our k-means algorithm
- `N_CLUSTERS` for the number of clusters (equals k)
- `COLORS` for the colors of the clusters
- `EPSILON` (the machine epsilon of floats) to treat differences smaller than the floating-point precision as zero

### Point in a 2-dimensional Plane

We start the code by defining a class to represent a point in a 2-dimensional plane. It has a single, simple method that determines its Euclidean distance to another point.

Short recap about the Euclidean distance: The Euclidean distance of two points
$p$ and $q$ is the square root of the sum of the squared deltas between their
coordinates in each dimension; or as a formula:

$$d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$

In other words, as we only need 2 dimensions for our points, it is merely the
square root of the sum of their squared `x` and `y` deltas,

$$d(p, q) = \sqrt{(p_x - q_x)^2 + (p_y - q_y)^2},$$

as shown in the code below.

class Point:
    """A colored point in the 2-dimensional plane."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates together with the plotting attributes."""
        self.x, self.y = x, y  # position in the plane
        self.color = color  # fill color used when plotting
        self.magnitude = magnitude  # marker size handed to pyplot

    def distance_to_point(self, point: Point):
        """Return the Euclidean distance between ``self`` and ``point``."""
        dx, dy = self.x - point.x, self.y - point.y
        squared_sum = dx ** 2 + dy ** 2
        return math.sqrt(squared_sum)

Additionally, we give each point the attributes `color` and `magnitude` (its marker size in the plot) to make it easier to visually distinguish the points.

### Point List

To ease handling these points a bit, we create a data structure called
`PointList`, which is merely a wrapper around a list with a few methods
specifically tailored to handling and plotting points.

class PointList:
    """A list of points plus helpers to aggregate and scatter-plot them."""

    def __init__(self, points: List[Point] = None, marker: str = "x"):
        """Wrap ``points`` (a fresh empty list when falsy) and remember the marker."""
        self.points = points if points else []
        self.marker = marker  # scatter marker symbol

    # Per-point attribute columns, in list order

    @property
    def x_values(self):
        """x coordinate of every point."""
        return [p.x for p in self.points]

    @property
    def y_values(self):
        """y coordinate of every point."""
        return [p.y for p in self.points]

    @property
    def colors(self):
        """Color of every point."""
        return [p.color for p in self.points]

    @property
    def magnitudes(self):
        """Marker size of every point."""
        return [p.magnitude for p in self.points]

    def plot(self):
        """Draw the points as a scatter plot and return pyplot's result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add ``point`` to the end of the list."""
        self.points.append(point)

    def len(self):
        """Number of points held."""
        return len(self.points)

    # Aggregates

    @property
    def x_sum(self):
        """Sum of all x coordinates."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """Sum of all y coordinates."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """Mean x coordinate (ZeroDivisionError when empty)."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """Mean y coordinate (ZeroDivisionError when empty)."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Root of the summed squared distances between index-paired points.

        Points are paired by position; surplus points of the longer list are ignored.
        """
        pairs = zip(self.points, other_points_list.points)
        total = sum((a.x - b.x) ** 2 + (a.y - b.y) ** 2 for a, b in pairs)
        return math.sqrt(total)

We make use of the Python built-in decorator `property` for various properties
of this list type. It is mere syntactic sugar that lets methods be accessed
like (read-only) attributes, e.g. `points.x_avg` instead of `points.x_avg()`.

### Create Random PointList

To create some data we can cluster, we write two simple functions: one creates a
random `Point` in a 1x1 sized plane, and the other creates a `PointList` that
contains `n` of these random points.

def random_point(**kwargs):
    """Create a point with uniformly random coordinates in [0, 1) x [0, 1).

    All keyword arguments (e.g. ``color``, ``magnitude``) are passed on to ``Point``.
    """
    # Called without arguments, np.random.rand() returns a plain float. rand(1)
    # returned a one-element array instead, which math.sqrt() only accepts
    # through NumPy's deprecated array-to-scalar conversion.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)


def random_points(n: int):
    """Create a PointList holding ``n`` random points (default color)."""
    points = PointList()
    for _ in range(n):
        points.append(random_point())
    return points

### Create Random K-Clusters

The plain k-means algorithm is initialized with randomly placed centres, so we create a helper function that does exactly that.

def create_random_cluster_centres(k: int):
    """Return a PointList of ``k`` randomly placed, colored cluster centres.

    Each centre takes the next color from ``COLORS``. When ``k`` exceeds the
    number of colors, they are reused cyclically, so exactly ``k`` centres are
    always returned.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Modulo instead of zip(COLORS, range(k)): zip would stop after
        # len(COLORS) centres and silently return fewer than k.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres

Besides being drawn with a different marker (and larger, in a cluster color), there is basically no difference between a point and a cluster centre, which is why we reuse `random_point` to create the centres.

### Create K Point Lists

Another helper function we're going to use is `create_k_point_lists`.
It returns a list of `k` empty `PointList`s, one per cluster, to ease handling them.

def create_k_point_lists(k: int):
    """Build ``k`` independent, empty PointLists (one per cluster)."""
    return [PointList() for _ in range(k)]

### Cluster Points

The first real step of k-means is the `cluster_points` function. It receives a
`PointList` of the points and a `PointList` of the centres and assigns each
point to the cluster of the centre it is closest to.

def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to the cluster of its nearest centre.

    The point goes to the closest centre, or the first such centre on ties, and
    takes that centre's color (visualization only). Returns one PointList per
    centre, in centre order.
    """
    centre_list = centres.points
    clusters = create_k_point_lists(len(centre_list))
    for point in points.points:
        distances = [point.distance_to_point(c) for c in centre_list]
        nearest = distances.index(min(distances))  # first index wins on ties
        clusters[nearest].append(point)
        point.color = centre_list[nearest].color
    return clusters

### Calculate New Centres

The other step of k-means is `calculate_new_centres`, which places the centre of
each cluster at the average `x` and `y` coordinates of all the points in that
cluster.

def calculate_new_centres(clusters: List[PointList], previous_centres: PointList | None = None):
    """Return one new centre per cluster, placed at the mean of its points.

    Args:
        clusters: The clusters as produced by ``cluster_points``.
        previous_centres: Optional centres of the previous iteration, in the same
            order as ``clusters``. An empty cluster has no mean, so its previous
            centre is kept in that case.

    Raises:
        ValueError: If a cluster is empty and no ``previous_centres`` were given
            (instead of an opaque ZeroDivisionError from the average).
    """
    new_centres = PointList(marker="o")  # Centres are drawn with a circle marker
    for index, cluster in enumerate(clusters):
        if cluster.len() == 0:
            if previous_centres is None:
                raise ValueError(
                    f"Cluster {index} is empty; pass previous_centres to keep its old centre"
                )
            old_centre = previous_centres.points[index]
            new_centres.append(
                Point(x=old_centre.x, y=old_centre.y, color=old_centre.color, magnitude=150)
            )
            continue
        new_centres.append(
            Point(
                x=cluster.x_avg,  # Mean x of the points in the cluster
                y=cluster.y_avg,  # Mean y of the points in the cluster
                color=cluster.colors[0],  # cluster_points recolored the points with their centre's color
                magnitude=150,  # Centres are displayed larger to identify them visually
            )
        )
    return new_centres

### Plot Styling

Not much to say about this code segment. It simply applies some style settings to the plot to match the general styling of this page.

def plot_styling():
    """Open a new figure with the page's dark background and white axes."""
    background = "#111827"
    plt.figure(facecolor=background)
    axis = plt.gca()
    axis.set_facecolor(background)
    # Frame lines on every side of the plot
    for side in ('bottom', 'top', 'right', 'left'):
        axis.spines[side].set_color('white')
    # Tick marks and labels of both axes
    for direction in ('x', 'y'):
        axis.tick_params(axis=direction, colors='white')

### K-Means Algorithm

Finally, we put all the code together and run the `k_means`
algorithm. If the file is run directly, the `if __name__ == '__main__':`
condition is met and the code listed below it is executed.

def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving and return the final centres.

    Each iteration assigns every point to its nearest centre (recoloring the
    point) and then moves every centre to the mean of its cluster.
    """
    difference = 1  # Any start value >= EPSILON so the loop body runs at least once
    # difference is a non-negative distance, so no abs() is needed
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)  # Assign points to their closest centre
        new_centres = calculate_new_centres(new_clusters)  # Move centres to their cluster means
        difference = new_centres.difference(centres)  # How far the centres moved
        centres = new_centres
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    # Run the Algorithm
    k_means(points, centres)

We create a `PointList` of `N_POINTS` random points to cluster and initialize
k-means with `N_CLUSTERS` random points as the centres. When k-means runs,
three steps are repeated until the centres of the clusters no longer move
(in relation to the previous iteration):

- Cluster the points by assigning them to the closest centre
- Calculate the new centres by getting the average position of each point in the cluster
- Calculate how far the centres moved in relation to the previous iteration to determine whether the termination condition is met

### Create the animation

To create an animation of the algorithm, we take a snapshot of each iteration at a specific point in the algorithm. This modifies the last part of the code slightly:

def k_means(points: PointList, centres: PointList, frames: list | None = None):
    """Run k-means, snapshot every iteration and save the snapshots as ``k-means.gif``.

    Args:
        points: The points to cluster; they are recolored in place.
        centres: The initial centres.
        frames: In-memory image frames recorded so far (e.g. the initial view);
            the iteration frames are appended to it. A new list is used when omitted.

    Returns:
        The centres after convergence.
    """
    if frames is None:
        frames = []
    difference = 1  # Any start value >= EPSILON so the loop body runs at least once
    n = 1  # Iteration counter, used for the frame titles
    while difference >= EPSILON:  # difference is a non-negative distance
        new_clusters = cluster_points(points, centres)  # Assign points to their closest centre
        new_centres = calculate_new_centres(new_clusters)  # Move centres to their cluster means
        difference = new_centres.difference(centres)  # How far the centres moved

        # Animation only: previous centres, points colored by the new assignment
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()  # In-memory buffer for the frame image
        plt.savefig(frame_bytes, format="png")
        frame_bytes.seek(0)  # Rewind so the reader starts at the first byte of the image
        frames.append(frame_bytes)
        plt.close("all")  # Close all figures to free the memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            writer.append_data(imageio.imread(frame))
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    frames = []

    # Initial View
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    frame_bytes = BytesIO()
    plt.savefig(frame_bytes, format="png")
    frame_bytes.seek(0)
    frames.append(frame_bytes)
    plt.close("all")

    # Run the Algorithm
    k_means(points, centres, frames)