In this article we're going to look into k-means and how it can be visualized using Python. Below we have the code necessary to create an animation. You may simply copy and paste it to create your own animation of k-means. However, we also explain each of the functions below it in the explanation part.
Note: We sacrifice some modularity/extensibility and break certain coding principles in this example to maximize its readability and minimize the amount of code required to reach the goal described in the title.
The K-Means Animation
Code to visualize K-Means
from __future__ import annotations
import math
import sys
from io import BytesIO
from typing import List
import imageio
import matplotlib.pyplot as plt
import numpy as np
# How many random points are generated and clustered.
N_POINTS = 500

# How many clusters are formed (the "k" in k-means).
N_CLUSTERS = 5

# One plot colour per cluster.
COLORS = [
    "#3498db",
    "#2ecc71",
    "#f1c40f",
    "#9b59b6",
    "#e74c3c",
]

# Machine epsilon of the float type; stands in for "zero" in float comparisons.
EPSILON = sys.float_info.epsilon
class Point:
    """A point in the 2-dimensional plane carrying its own plot styling."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates together with the colour and marker size.

        :param x: The x coordinate
        :param y: The y coordinate
        :param color: The colour used when the point is plotted
        :param magnitude: The marker size used when the point is plotted
        """
        self.x = x
        self.y = y
        self.color = color
        self.magnitude = magnitude

    def distance_to_point(self, point: "Point"):
        """Return the Euclidean distance between this point and ``point``."""
        delta_x = self.x - point.x  # The difference in x direction
        delta_y = self.y - point.y  # The difference in y direction
        # math.hypot is the standard-library Euclidean norm; unlike squaring
        # the deltas by hand it does not overflow for very large coordinates.
        return math.hypot(delta_x, delta_y)
class PointList:
    """A thin wrapper around a list of points with plotting and averaging helpers.

    The elements only need ``x``, ``y``, ``color`` and ``magnitude`` attributes.
    """

    def __init__(self, points: "List[Point] | None" = None, marker: str = "x"):
        """Wrap ``points`` (or start empty) and remember the plot marker.

        :param points: The list to wrap; it is stored as-is, not copied
        :param marker: The marker symbol used by :meth:`plot`
        """
        # Compare against None instead of testing truthiness, so that an empty
        # list handed in by the caller is kept rather than silently replaced.
        self.points = points if points is not None else []
        self.marker = marker

    @property
    def x_values(self):
        """ A list of the x values of all points in the list """
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """ A list of the y values of all points in the list """
        return [point.y for point in self.points]

    @property
    def colors(self):
        """ A list of the colors of all points in the list """
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """ A list of the magnitudes of all points in the list """
        return [point.magnitude for point in self.points]

    def plot(self):
        """ Draws the points as a scatter plot and returns the result of plt.scatter """
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes
        )

    def append(self, point):
        """ Adds a point to the PointList """
        self.points.append(point)

    def len(self):
        """ Returns the number of points in the PointList """
        return len(self.points)

    @property
    def x_sum(self):
        """ The sum of the x values of all points in the list """
        return sum(self.x_values)

    @property
    def y_sum(self):
        """ The sum of the y values of all points in the list """
        return sum(self.y_values)

    @property
    def x_avg(self):
        """ The average x value; raises ZeroDivisionError if the list is empty """
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """ The average y value; raises ZeroDivisionError if the list is empty """
        return self.y_sum / self.len()

    def difference(self, other_points_list: "PointList") -> float:
        """Return the overall distance between the points of two lists.

        Points are paired by position; the result is the square root of the
        summed squared distances of all pairs. If the lists differ in length,
        the surplus points of the longer one are ignored.
        """
        pairs = zip(self.points, other_points_list.points)
        return math.sqrt(sum(
            (own_point.x - list_point.x) ** 2 + (own_point.y - list_point.y) ** 2
            for own_point, list_point in pairs
        ))
def random_point(**kwargs):
    """Create a random point with x and y drawn uniformly from [0, 1).

    All keyword arguments are forwarded unchanged to the ``Point`` constructor.
    """
    # Calling rand() without a size yields a single float. rand(1) would
    # yield a one-element array, which every later calculation would have to
    # convert implicitly (deprecated in recent NumPy versions).
    x = float(np.random.rand())
    y = float(np.random.rand())
    return Point(x, y, **kwargs)
def random_points(n: int):
    """ Build a PointList holding n freshly generated random points """
    return PointList([random_point() for _ in range(n)])
def create_random_cluster_centres(k: int):
    """Return a PointList of k randomly placed centres.

    The centres use the "o" marker, a magnitude of 150 and the colours of the
    module-level ``COLORS`` palette.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Index the COLORS constant (the lowercase name `colors` is not
        # defined) and wrap around, so k > len(COLORS) still yields k centres
        # instead of being silently truncated to the palette size.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres
def create_k_point_lists(k: int):
    """ Build a list of k separate, empty PointLists """
    return [PointList() for _ in range(k)]
def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to the cluster of its nearest centre.

    Returns one PointList per centre, in the order of the centres. As a side
    effect (visualization only) every point takes on the colour of its centre.
    """
    centre_points = centres.points
    clusters = create_k_point_lists(len(centre_points))  # One cluster per centre

    for candidate in points.points:
        # Distance from this point to every centre, in centre order
        distances = [candidate.distance_to_point(centre) for centre in centre_points]
        # Position of the smallest distance; the first one wins on a tie
        nearest = min(range(len(distances)), key=distances.__getitem__)
        clusters[nearest].append(candidate)
        candidate.color = centre_points[nearest].color

    return clusters
def calculate_new_centres(clusters: List[PointList], previous_centres: PointList = None):
    """Calculate the new centre of every cluster.

    :param clusters: One PointList per cluster
    :param previous_centres: Optional centres the clusters were built from,
        in the same order as ``clusters``. Used as a fallback for clusters
        that received no points.
    :return: A PointList (marker "o") with one centre per cluster
    :raises ValueError: If a cluster is empty and no previous centres are given
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if cluster.points:
            centre = Point(
                x=cluster.x_avg,  # Average x value of all points in the cluster
                y=cluster.y_avg,  # Average y value of all points in the cluster
                color=cluster.colors[0],  # The color of the first point
                magnitude=150  # Centres are displayed larger to identify them visually
            )
        elif previous_centres is not None:
            # An empty cluster has neither an average nor a first colour, so
            # the centre simply stays where it was.
            old_centre = previous_centres.points[index]
            centre = Point(
                x=old_centre.x,
                y=old_centre.y,
                color=old_centre.color,
                magnitude=150
            )
        else:
            raise ValueError(
                f"cluster {index} is empty; pass previous_centres to keep its old centre"
            )
        new_centres.append(centre)
    return new_centres
def plot_styling():
    """ Opens a new dark figure and colours its axes, spines and ticks white """
    background = "#111827"
    foreground = 'white'

    plt.figure(facecolor=background)  # New figure with the page background
    axis = plt.gca()  # The axes of that figure
    axis.set_facecolor(background)

    for side in ('bottom', 'top', 'right', 'left'):  # The four borders of the axes
        axis.spines[side].set_color(foreground)

    for direction in ('x', 'y'):  # Tick marks and tick labels of both axes
        axis.tick_params(axis=direction, colors=foreground)
def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving and write the animation.

    One frame per iteration is appended to the module-level ``frames`` list,
    which has to exist before this function is called. Afterwards all frames
    are written to ``k-means.gif``.

    :param points: The points to cluster; they are recoloured while clustering
    :param centres: The initial cluster centres
    :return: The centres of the final iteration
    """
    difference = 1.0  # Initial difference, only there to enter the loop
    n = 1  # Iteration counter, used for the frame title
    # The difference is a Euclidean distance and thus never negative, so a
    # plain comparison against the machine epsilon is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)  # Assign each point to its closest centre
        new_centres = calculate_new_centres(new_clusters)  # Move each centre to the average of its points
        difference = new_centres.difference(centres)  # How far the centres moved in this iteration

        # Animation only: the frame shows the centres the assignment was made against
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()  # In-memory target for the rendered figure
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # Close all plots to free the memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            # Writing left the buffer positioned at its end; rewind it so the
            # image is read from the first byte no matter how the reader
            # treats the current position.
            frame.seek(0)
            writer.append_data(imageio.imread(frame))

    return centres
if __name__ == "__main__":
    frames = []  # Collects the rendered frames; k_means appends to it as well
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # First frame: the untouched points and the random start centres
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Cluster the points and write the animation
    k_means(points, centres)
Explanation
The code above should run without any issues. If you encounter any, though, feel free to leave a comment at the bottom of the page. However, if you modified the code and ran into issues, the following explanations might help.
Libraries
First of all we need some libraries. We use:
- the `__future__` package to be able to use type hints before their definition
- the Python built-in module `math` for a few mathematical methods
- the `sys` package to determine the epsilon of the float datatype
- `io` to store byte objects in memory
- the `typing` package for type hints
- `imageio` to create and save the frames of our animation
- `pyplot` from `matplotlib` - as the name lets one assume - for plotting the data, or here, the pictures, as they're nothing else than matrices themselves. By convention, `matplotlib.pyplot` is imported as `plt`
- `numpy` to ease the array handling and outsource some operations to the way faster C-based libraries included. By convention, we import `numpy` as `np` to reduce the length of the code lines
from __future__ import annotations
import math
import sys
from io import BytesIO
from typing import List
import imageio
import matplotlib.pyplot as plt
import numpy as np
# How many random points are generated and clustered.
N_POINTS = 500

# How many clusters are formed (the "k" in k-means).
N_CLUSTERS = 5

# One plot colour per cluster.
COLORS = [
    "#3498db",
    "#2ecc71",
    "#f1c40f",
    "#9b59b6",
    "#e74c3c",
]

# Machine epsilon of the float type; stands in for "zero" in float comparisons.
EPSILON = sys.float_info.epsilon
We also define a few static variables, respectively:
- `N_POINTS` for the number of dots we want to display in our k-means algorithm
- `N_CLUSTERS` for the number of clusters (equals k)
- `COLORS` for the colors of the clusters
- `EPSILON` to calculate zero with the machine inaccuracy correctly
Point in a 2-dimensional Plane
We start the code by defining a class to represent a point in a 2-dimensional plane. It has a single and simple method to determine its Euclidean distance to another point.
Short recap about the Euclidean distance: The Euclidean distance of two points p and q is calculated as the square root of the sum of the squared deltas between their coordinates in each dimension; or as a formula:
$$d(p, q) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}$$
In other words, as we only need 2 dimensions for our points, it is merely the square root of the sum of their squared x and y deltas, as shown in the code below.
class Point:
    """A point in the 2-dimensional plane carrying its own plot styling."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates together with the colour and marker size.

        :param x: The x coordinate
        :param y: The y coordinate
        :param color: The colour used when the point is plotted
        :param magnitude: The marker size used when the point is plotted
        """
        self.x = x
        self.y = y
        self.color = color
        self.magnitude = magnitude

    def distance_to_point(self, point: "Point"):
        """Return the Euclidean distance between this point and ``point``."""
        delta_x = self.x - point.x  # The difference in x direction
        delta_y = self.y - point.y  # The difference in y direction
        # math.hypot is the standard-library Euclidean norm; unlike squaring
        # the deltas by hand it does not overflow for very large coordinates.
        return math.hypot(delta_x, delta_y)
Additionally, we create an attribute color and magnitude to make it easier to visually distinguish the points.
Point List
To ease handling these points a bit, we create a data structure called
PointList
which is a mere interpretation of a list with a few methods
specifically tailored to handling and plotting points.
class PointList:
    """A thin wrapper around a list of points with plotting and averaging helpers.

    The elements only need ``x``, ``y``, ``color`` and ``magnitude`` attributes.
    """

    def __init__(self, points: "List[Point] | None" = None, marker: str = "x"):
        """Wrap ``points`` (or start empty) and remember the plot marker.

        :param points: The list to wrap; it is stored as-is, not copied
        :param marker: The marker symbol used by :meth:`plot`
        """
        # Compare against None instead of testing truthiness, so that an empty
        # list handed in by the caller is kept rather than silently replaced.
        self.points = points if points is not None else []
        self.marker = marker

    @property
    def x_values(self):
        """ A list of the x values of all points in the list """
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """ A list of the y values of all points in the list """
        return [point.y for point in self.points]

    @property
    def colors(self):
        """ A list of the colors of all points in the list """
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """ A list of the magnitudes of all points in the list """
        return [point.magnitude for point in self.points]

    def plot(self):
        """ Draws the points as a scatter plot and returns the result of plt.scatter """
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes
        )

    def append(self, point):
        """ Adds a point to the PointList """
        self.points.append(point)

    def len(self):
        """ Returns the number of points in the PointList """
        return len(self.points)

    @property
    def x_sum(self):
        """ The sum of the x values of all points in the list """
        return sum(self.x_values)

    @property
    def y_sum(self):
        """ The sum of the y values of all points in the list """
        return sum(self.y_values)

    @property
    def x_avg(self):
        """ The average x value; raises ZeroDivisionError if the list is empty """
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """ The average y value; raises ZeroDivisionError if the list is empty """
        return self.y_sum / self.len()

    def difference(self, other_points_list: "PointList") -> float:
        """Return the overall distance between the points of two lists.

        Points are paired by position; the result is the square root of the
        summed squared distances of all pairs. If the lists differ in length,
        the surplus points of the longer one are ignored.
        """
        pairs = zip(self.points, other_points_list.points)
        return math.sqrt(sum(
            (own_point.x - list_point.x) ** 2 + (own_point.y - list_point.y) ** 2
            for own_point, list_point in pairs
        ))
We make use of the Python built-in decorator property
for various properties
of this list type. It is mere syntactic sugar and transforms class methods to behave
like class attributes.
Create Random PointList
To create some data we can cluster, we write two simple functions that create
random Points
in a 1x1 sized plane, also we create a PointList
that contains
n
of these random points.
def random_point(**kwargs):
    """Create a random point with x and y drawn uniformly from [0, 1).

    All keyword arguments are forwarded unchanged to the ``Point`` constructor.
    """
    # Calling rand() without a size yields a single float. rand(1) would
    # yield a one-element array, which every later calculation would have to
    # convert implicitly (deprecated in recent NumPy versions).
    x = float(np.random.rand())
    y = float(np.random.rand())
    return Point(x, y, **kwargs)
def random_points(n: int):
    """ Build a PointList holding n freshly generated random points """
    return PointList([random_point() for _ in range(n)])
Create Random K-Clusters
The plain k-means algorithm is initialized by using randomly placed centres so we create a helper function that does exactly that.
def create_random_cluster_centres(k: int):
    """Return a PointList of k randomly placed centres.

    The centres use the "o" marker, a magnitude of 150 and the colours of the
    module-level ``COLORS`` palette.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Wrap around the palette, so k > len(COLORS) still yields k centres
        # instead of being silently truncated to the palette size.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres
Besides being represented by another marker, there is basically no difference between a point and a cluster centre, which is why we reuse creating a random point.
Create K Point Lists
Another helper function we're going to use is the create_k_point_lists
method.
It will return a PointList
k
times, to ease handling them.
def create_k_point_lists(k: int):
    """ Build a list of k separate, empty PointLists """
    return [PointList() for _ in range(k)]
Cluster Points
The first real submethod of k-means is the cluster points method. It receives a
PointList
of the points and a PointList
of the centres and assigns each
point to a new cluster list which is based on the centre it is the closest to
def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to the cluster of its nearest centre.

    Returns one PointList per centre, in the order of the centres. As a side
    effect (visualization only) every point takes on the colour of its centre.
    """
    centre_points = centres.points
    clusters = create_k_point_lists(len(centre_points))  # One cluster per centre

    for candidate in points.points:
        # Distance from this point to every centre, in centre order
        distances = [candidate.distance_to_point(centre) for centre in centre_points]
        # Position of the smallest distance; the first one wins on a tie
        nearest = min(range(len(distances)), key=distances.__getitem__)
        clusters[nearest].append(candidate)
        candidate.color = centre_points[nearest].color

    return clusters
Calculate New Centres
Another submethod of k-means is to calculate_new_centres
based on the average
x
and y
coordinates of all the points in the cluster.
def calculate_new_centres(clusters: List[PointList], previous_centres: PointList = None):
    """Calculate the new centre of every cluster.

    :param clusters: One PointList per cluster
    :param previous_centres: Optional centres the clusters were built from,
        in the same order as ``clusters``. Used as a fallback for clusters
        that received no points.
    :return: A PointList (marker "o") with one centre per cluster
    :raises ValueError: If a cluster is empty and no previous centres are given
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if cluster.points:
            centre = Point(
                x=cluster.x_avg,  # Average x value of all points in the cluster
                y=cluster.y_avg,  # Average y value of all points in the cluster
                color=cluster.colors[0],  # The color of the first point
                magnitude=150  # Centres are displayed larger to identify them visually
            )
        elif previous_centres is not None:
            # An empty cluster has neither an average nor a first colour, so
            # the centre simply stays where it was.
            old_centre = previous_centres.points[index]
            centre = Point(
                x=old_centre.x,
                y=old_centre.y,
                color=old_centre.color,
                magnitude=150
            )
        else:
            raise ValueError(
                f"cluster {index} is empty; pass previous_centres to keep its old centre"
            )
        new_centres.append(centre)
    return new_centres
Plot Styling
Not much to say about this code segment. It simply applies some styling settings to the plot to match the general styling of this page.
def plot_styling():
    """ Opens a new dark figure and colours its axes, spines and ticks white """
    background = "#111827"
    foreground = 'white'

    plt.figure(facecolor=background)  # New figure with the page background
    axis = plt.gca()  # The axes of that figure
    axis.set_facecolor(background)

    for side in ('bottom', 'top', 'right', 'left'):  # The four borders of the axes
        axis.spines[side].set_color(foreground)

    for direction in ('x', 'y'):  # Tick marks and tick labels of both axes
        axis.tick_params(axis=direction, colors=foreground)
K-Means Algorithm
Finally, we put all the code together and run the k_means algorithm. If the file gets called directly, the if __name__ == '__main__': condition will be met and the code listed below it will be executed.
def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving.

    :param points: The points to cluster; they are recoloured while clustering
    :param centres: The initial cluster centres
    :return: The centres of the final iteration
    """
    difference = 1.0  # Initial difference, only there to enter the loop
    # The difference is a Euclidean distance and thus never negative, so a
    # plain comparison against the machine epsilon is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)  # Assign each point to its closest centre
        new_centres = calculate_new_centres(new_clusters)  # Move each centre to the average of its points
        difference = new_centres.difference(centres)  # How far the centres moved in this iteration
        centres = new_centres  # The new centres are the starting point of the next iteration
    # Hand the result back to the caller; without it the final centres are lost.
    return centres
if __name__ == "__main__":
    # The data to cluster and the random start centres
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # Cluster the points
    k_means(points=points, centres=centres)
We create a PointList of N_POINTS random points to cluster and initialize k-means with N_CLUSTERS random points as the centres. When k-means is run, there are 3 steps that are repeated until the centres of each cluster do not move anymore (in relation to the previous iteration):
- Cluster the points by assigning them to the closest centre
- Calculate the new centres by getting the average position of each point in the cluster
- Calculate the distance of the centres in relation to the previous iteration to determine whether the termination condition is met
Create the animation
To create an animation of these, we take a kind of snapshot of each iteration at specific points in the algorithm. This modifies the last part of the code slightly:
def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving and write the animation.

    One frame per iteration is appended to the module-level ``frames`` list,
    which has to exist before this function is called. Afterwards all frames
    are written to ``k-means.gif``.

    :param points: The points to cluster; they are recoloured while clustering
    :param centres: The initial cluster centres
    :return: The centres of the final iteration
    """
    difference = 1.0  # Initial difference, only there to enter the loop
    n = 1  # Iteration counter, used for the frame title
    # The difference is a Euclidean distance and thus never negative, so a
    # plain comparison against the machine epsilon is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)  # Assign each point to its closest centre
        new_centres = calculate_new_centres(new_clusters)  # Move each centre to the average of its points
        difference = new_centres.difference(centres)  # How far the centres moved in this iteration

        # Animation only: the frame shows the centres the assignment was made against
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()  # In-memory target for the rendered figure
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # Close all plots to free the memory
        # / Animation only

        centres = new_centres
        n += 1

    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            # Writing left the buffer positioned at its end; rewind it so the
            # image is read from the first byte no matter how the reader
            # treats the current position.
            frame.seek(0)
            writer.append_data(imageio.imread(frame))

    return centres
if __name__ == "__main__":
    frames = []  # Collects the rendered frames; k_means appends to it as well
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # First frame: the untouched points and the random start centres
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Cluster the points and write the animation
    k_means(points, centres)