In this article we're going to look into k-means and how it can be visualized using Python. Below we have the code necessary to create an animation. You may simply copy and paste it to create your own animation of k-means. However, we also explain each of the functions below it in the explanation part.
Note: We sacrifice some modularity/extensibility and break certain code principles in this example to maximize its readability and minimize the amount of code required to reach the goal described in the title.
The K-Means Animation
Code to visualize K-Means
from __future__ import annotations

import math
import sys
from io import BytesIO
from typing import List

import imageio
import matplotlib.pyplot as plt
import numpy as np

N_POINTS = 500  # The number of points to cluster
N_CLUSTERS = 5  # The number of clusters
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]  # The colors of each cluster
# Machine epsilon of float; a total centre shift below this value is treated as "did not move".
EPSILON = sys.float_info.epsilon


class Point:
    """A point in a 2-dimensional plane, with a color and size used for plotting."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates and the plot attributes of the point."""
        self.x = x  # The x coordinate
        self.y = y  # The y coordinate
        self.color = color  # The color of the point
        self.magnitude = magnitude  # The marker size passed to pyplot

    def distance_to_point(self, point: Point):
        """Return the Euclidean distance between this point and ``point``."""
        return math.hypot(self.x - point.x, self.y - point.y)


class PointList:
    """A list of points that share one plot marker, with plotting and averaging helpers."""

    def __init__(self, points: List[Point] = None, marker: str = "x"):
        """Wrap ``points`` (a new empty list when omitted) and remember the marker."""
        # Test for None explicitly: a caller-supplied empty list must be kept, not replaced.
        self.points = [] if points is None else points
        self.marker = marker  # The marker (symbol in the pyplot)

    def __len__(self):
        """Return the number of points, so the builtin ``len()`` works as well."""
        return len(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Scatter-plot all points on the current pyplot figure and return the result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the PointList."""
        self.points.append(point)

    def len(self):
        """Return the number of points in the PointList."""
        return len(self.points)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The average x value; raises ZeroDivisionError when the list is empty."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The average y value; raises ZeroDivisionError when the list is empty."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Return the square root of the summed squared distances between paired points.

        Points are paired by position; surplus points of the longer list are ignored.
        """
        squared = (
            (own.x - other.x) ** 2 + (own.y - other.y) ** 2
            for own, other in zip(self.points, other_points_list.points)
        )
        return math.sqrt(sum(squared))


def random_point(**kwargs):
    """Create a random point with x and y in [0, 1); keyword arguments are passed on to Point."""
    # rand() without a size yields a plain float; rand(1) would store 1-element arrays.
    return Point(np.random.rand(), np.random.rand(), **kwargs)


def random_points(n: int):
    """Create a PointList of n random points."""
    points = PointList()
    for _ in range(n):
        points.append(random_point())
    return points


def create_random_cluster_centres(k: int):
    """Return a PointList of k random centres, drawn large and with the "o" marker."""
    centres = PointList(marker="o")
    for index in range(k):
        # Cycle through COLORS so k may exceed the palette size without losing centres.
        centres.append(random_point(color=COLORS[index % len(COLORS)], magnitude=150))
    return centres


def create_k_point_lists(k: int):
    """Return a list of k new, empty PointLists."""
    return [PointList() for _ in range(k)]


def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its closest centre.

    Returns one PointList per centre, in centre order. As a side effect each point
    takes over the color of its centre (visualization only).
    """
    centre_points = centres.points
    clusters = create_k_point_lists(len(centre_points))
    for point in points.points:
        distances = [point.distance_to_point(centre) for centre in centre_points]
        centre_index = distances.index(min(distances))  # first closest centre wins ties
        clusters[centre_index].append(point)
        point.color = centre_points[centre_index].color
    return clusters


def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList holding the mean position of every cluster.

    Each cluster must contain at least one point: an empty cluster raises
    ZeroDivisionError while averaging.
    """
    new_centres = PointList(marker="o")
    for cluster in clusters:
        new_centres.append(
            Point(
                x=cluster.x_avg,  # Average x value of all points in the cluster
                y=cluster.y_avg,  # Average y value of all points in the cluster
                color=cluster.colors[0],  # The color of the first point
                magnitude=150,  # Centres are displayed larger to identify them visually
            )
        )
    return new_centres


def plot_styling():
    """Open a new dark-themed figure with white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)  # Background color
    axis = plt.gca()
    axis.set_facecolor(background)  # Axis background color
    for side in ("bottom", "top", "right", "left"):
        axis.spines[side].set_color("white")
    axis.tick_params(axis="both", colors="white")  # Tick colors of both axes


def _render_frame(points: PointList, centres: PointList, title: str) -> BytesIO:
    """Draw points and centres on a fresh styled figure and return the saved image bytes."""
    plot_styling()
    points.plot()
    centres.plot()
    plt.title(title, color="white")
    frame_bytes = BytesIO()
    plt.savefig(frame_bytes)  # Save the plot to the in-memory bytes
    plt.close("all")  # Close all plots to free the memory
    return frame_bytes


def k_means(points: PointList, centres: PointList, frames=None, output_path="k-means.gif"):
    """Run k-means on ``points`` starting from ``centres`` and write an animated GIF.

    One frame is rendered per iteration and appended to ``frames`` (a new list when
    omitted), so frames rendered beforehand, e.g. an initial view, come first in the
    animation. Returns the final centres.
    """
    frames = [] if frames is None else frames
    difference = 1  # initial difference, large enough to enter the loop
    n = 1  # iteration counter, used for the frame titles
    while abs(difference) >= EPSILON:  # repeat while the centres still move
        new_clusters = cluster_points(points, centres)
        # A centre that attracted no point would make the averaging fail; seed its
        # cluster with the centre itself so that the centre stays where it is.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(centre)
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)
        frames.append(_render_frame(points, centres, f'Iteration {n}'))  # Animation only
        centres = new_centres
        n += 1
    with imageio.get_writer(output_path, mode='I', duration=0.5) as writer:
        for frame in frames:
            frame.seek(0)  # rewind: savefig left the stream position at the end
            writer.append_data(imageio.imread(frame))
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    # Initial view
    frames = [_render_frame(points, centres, "Initialization")]
    # Run the algorithm
    k_means(points, centres, frames)
Explanation
The code above should run without any issues. If you encounter any though, feel free to comment at the bottom of the page. However, if you modified the code and run into issues, the following explanations might help.
Libraries
First of all we need some libraries. We use:
- the `__future__` package to be able to use type hints before their definition
- the Python built-in module `math` for a few mathematical methods
- the `sys` package to determine the epsilon of the float datatype
- `io` to store byte objects in memory
- the `typing` package for type hints
- `imageio` to create and save the frames of our animation
- `pyplot` from `matplotlib` - as the name lets one assume - for plotting the data, or here, the pictures, as they're nothing else than matrices themselves. By convention `matplotlib.pyplot` is imported as `plt`
- `numpy` to ease the array handling and outsource some operations to the way faster C-based libraries included. By convention, we import `numpy` as `np` to reduce the length of the code lines
from __future__ import annotations

import math
import sys
from io import BytesIO
from typing import List

import imageio
import matplotlib.pyplot as plt
import numpy as np

N_POINTS = 500  # The number of points to cluster
N_CLUSTERS = 5  # The number of clusters
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"]  # The colors of each cluster
# Machine epsilon of float; a total centre shift below this value is treated as "did not move".
EPSILON = sys.float_info.epsilon
We also define a few static variables, respectively:
- `N_POINTS` for the number of dots we want to display in our k-means algorithm
- `N_CLUSTERS` for the number of clusters (equals k)
- `COLORS` for the colors of the clusters
- `EPSILON` to compare against zero correctly despite machine inaccuracy
Point in a 2-dimensional Plane
We start the code by defining a class to represent a point in a 2-dimensional plane. It has a single and simple method, to determine its Euclidean distance to another point.
Short recap about the Euclidean distance: The Euclidean distance of two points p and q is calculated using the square root of the sum of the squared deltas between their dimensional coordinates; or as a formula:
$$d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$
In other words, as we only need 2 dimensions for our points, it is merely the square root of the sum of their squared x and y deltas, $d(p, q) = \sqrt{(p_x - q_x)^2 + (p_y - q_y)^2}$, as shown in the code below.
class Point:
    """A point in a 2-dimensional plane, with a color and size used for plotting."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates and the plot attributes of the point."""
        self.x = x  # The x coordinate
        self.y = y  # The y coordinate
        self.color = color  # The color of the point
        self.magnitude = magnitude  # The marker size passed to pyplot

    def distance_to_point(self, point: Point):
        """Return the Euclidean distance between this point and ``point``."""
        delta_x = self.x - point.x  # The difference in x direction
        delta_y = self.y - point.y  # The difference in y direction
        return math.hypot(delta_x, delta_y)
Additionally, we create an attribute color and magnitude to make it easier to visually distinguish the points.
Point List
To ease handling these points a bit, we create a data structure called
PointList
which is a mere interpretation of a list with a few methods
specifically tailored to handling and plotting points.
class PointList:
    """A list of points that share one plot marker, with plotting and averaging helpers."""

    def __init__(self, points: List[Point] = None, marker: str = "x"):
        """Wrap ``points`` (a new empty list when omitted) and remember the marker."""
        # Test for None explicitly: a caller-supplied empty list must be kept, not replaced.
        self.points = [] if points is None else points
        self.marker = marker  # The marker (symbol in the pyplot)

    def __len__(self):
        """Return the number of points, so the builtin ``len()`` works as well."""
        return len(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Scatter-plot all points on the current pyplot figure and return the result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the PointList."""
        self.points.append(point)

    def len(self):
        """Return the number of points in the PointList."""
        return len(self.points)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The average x value; raises ZeroDivisionError when the list is empty."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The average y value; raises ZeroDivisionError when the list is empty."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: PointList) -> float:
        """Return the square root of the summed squared distances between paired points.

        Points are paired by position; surplus points of the longer list are ignored.
        """
        squared = (
            (own.x - other.x) ** 2 + (own.y - other.y) ** 2
            for own, other in zip(self.points, other_points_list.points)
        )
        return math.sqrt(sum(squared))
We make use of the Python built-in decorator `property` for various properties of this list type. It is mere syntactic sugar that lets a method be read like an attribute, i.e. without call parentheses.
Create Random PointList
To create some data we can cluster, we write two simple functions that create
random Points
in a 1x1 sized plane, also we create a PointList
that contains
n
of these random points.
def random_point(**kwargs):
    """Create a random point with x and y in [0, 1); keyword arguments are passed on to Point."""
    # rand() without a size yields a plain float; rand(1) would store 1-element arrays
    # as coordinates.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)


def random_points(n: int):
    """Create a PointList of n random points."""
    points = PointList()
    for _ in range(n):
        points.append(random_point())
    return points
Create Random K-Clusters
The plain k-means algorithm is initialized by using randomly placed centres so we create a helper function that does exactly that.
def create_random_cluster_centres(k: int):
    """Return a PointList of k random centres, drawn large and with the "o" marker."""
    centres = PointList(marker="o")
    for index in range(k):
        # Cycle through COLORS so k may exceed the palette size; zipping with COLORS
        # would silently return fewer than k centres in that case.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres
Besides being represented by another marker, there is basically no difference between a point and a cluster centre, which is why we reuse creating a random point.
Create K Point Lists
Another helper function we're going to use is the `create_k_point_lists` function.
It returns a list of k empty `PointList` instances, to ease handling them.
def create_k_point_lists(k: int):
    """Return a list of k new, empty PointLists."""
    return [PointList() for _ in range(k)]
Cluster Points
The first real sub-step of k-means is the cluster points function. It receives a `PointList` of the points and a `PointList` of the centres and assigns each point to a new cluster list based on the centre it is closest to.
def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its closest centre.

    Returns one PointList per centre, in centre order. As a side effect each point
    takes over the color of its centre (visualization only).
    """
    centre_points = centres.points  # keep the parameters intact instead of rebinding them
    clusters = create_k_point_lists(len(centre_points))  # one cluster per centre
    for point in points.points:
        # Distance from this point to every centre
        distances = [point.distance_to_point(centre) for centre in centre_points]
        centre_index = distances.index(min(distances))  # first closest centre wins ties
        clusters[centre_index].append(point)
        point.color = centre_points[centre_index].color
    return clusters
Calculate New Centres
Another submethod of k-means is to calculate_new_centres
based on the average
x
and y
coordinates of all the points in the cluster.
def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList holding the mean position of every cluster.

    Each cluster must contain at least one point: an empty cluster raises
    ZeroDivisionError while averaging.
    """
    new_centres = PointList(marker="o")  # Centres are drawn with the "o" marker
    for cluster in clusters:
        new_centres.append(
            Point(
                x=cluster.x_avg,  # Average x value of all points in the cluster
                y=cluster.y_avg,  # Average y value of all points in the cluster
                color=cluster.colors[0],  # The color of the first point
                magnitude=150,  # Centres are displayed larger to identify them visually
            )
        )
    return new_centres
Plot Styling
Not much to say about this code segment. It's simply applying some styling settings for the plot to meet the general styling of this page.
def plot_styling():
    """Open a new dark-themed figure with white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)  # Background color
    axis = plt.gca()  # The axis object of the new figure
    axis.set_facecolor(background)  # Axis background color
    for side in ("bottom", "top", "right", "left"):
        axis.spines[side].set_color("white")  # Axis border color
    axis.tick_params(axis="both", colors="white")  # Tick colors of both axes
K-Means Algorithm
Finally we put all the code together and run the `k_means` algorithm. If the file gets called directly, the `if __name__ == '__main__':` condition will be met and the code listed below it will be executed.
def k_means(points: PointList, centres: PointList):
    """Run k-means on ``points`` starting from ``centres`` and return the final centres.

    Repeats clustering and re-centering until the total shift of the centres
    between two iterations drops below EPSILON.
    """
    difference = 1  # initial difference, large enough to enter the loop
    while abs(difference) >= EPSILON:  # repeat while the centres still move
        # Assign every point to its closest centre
        new_clusters = cluster_points(points, centres)
        # A centre that attracted no point would make the averaging fail; seed its
        # cluster with the centre itself so that the centre stays where it is.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(centre)
        # Move every centre to the mean position of its cluster
        new_centres = calculate_new_centres(new_clusters)
        # The distance between the centres of this iteration and the previous one
        difference = new_centres.difference(centres)
        centres = new_centres  # The new centres are the centres of the next iteration
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    # Run the algorithm
    k_means(points, centres)
We create a `PointList` of `N_POINTS` random points to cluster and initialize k-means with `N_CLUSTERS` random points as the centres. When k-means is run, there are 3 steps that are repeated until the centres of each cluster do not move anymore (in relation to the previous iteration):
- Cluster the points by assigning them to the closest centre
- Calculate the new centres by getting the average position of each point in the cluster
- Calculate the distance the centres moved in relation to the previous iteration to determine whether the termination condition is met
Create the animation
To create an animation of these, we take a kind of snapshot of each iteration at specific points in the algorithm. This modifies the last part of the code slightly:
def _render_frame(points: PointList, centres: PointList, title: str) -> BytesIO:
    """Draw points and centres on a fresh styled figure and return the saved image bytes."""
    plot_styling()  # Apply some styles
    points.plot()  # Plot the points
    centres.plot()  # Plot the centres
    plt.title(title, color="white")
    frame_bytes = BytesIO()  # Create an in-memory bytes object
    plt.savefig(frame_bytes)  # Save the plot to the in-memory bytes
    plt.close("all")  # Close all plots to free the memory
    return frame_bytes


def k_means(points: PointList, centres: PointList, frames=None, output_path="k-means.gif"):
    """Run k-means on ``points`` starting from ``centres`` and write an animated GIF.

    One frame is rendered per iteration and appended to ``frames`` (a new list when
    omitted), so frames rendered beforehand, e.g. an initial view, come first in the
    animation. Returns the final centres.
    """
    frames = [] if frames is None else frames
    difference = 1  # initial difference, large enough to enter the loop
    n = 1  # iteration counter, used for the frame titles
    while abs(difference) >= EPSILON:  # repeat while the centres still move
        # Assign every point to its closest centre
        new_clusters = cluster_points(points, centres)
        # A centre that attracted no point would make the averaging fail; seed its
        # cluster with the centre itself so that the centre stays where it is.
        for cluster, centre in zip(new_clusters, centres.points):
            if cluster.len() == 0:
                cluster.append(centre)
        # Move every centre to the mean position of its cluster
        new_centres = calculate_new_centres(new_clusters)
        # The distance between the centres of this iteration and the previous one
        difference = new_centres.difference(centres)
        # Animation only: snapshot the new coloring together with the centres that produced it
        frames.append(_render_frame(points, centres, f'Iteration {n}'))
        centres = new_centres  # The new centres are the centres of the next iteration
        n += 1
    # Create an animated gif from the frames
    with imageio.get_writer(output_path, mode='I', duration=0.5) as writer:
        for frame in frames:
            frame.seek(0)  # rewind: savefig left the stream position at the end
            image = imageio.imread(frame)  # Read the bytes
            writer.append_data(image)  # Append it to the animation
    return centres


if __name__ == '__main__':
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)
    # Initial view
    frames = [_render_frame(points, centres, "Initialization")]
    # Run the algorithm
    k_means(points, centres, frames)