Learn how to create and visualize the k-means algorithm - a very basic clustering algorithm that is often taught in introductory data science classes.

In this article we're going to look into k-means and how it can be visualized using Python. Below we have the code necessary to create an animation. You may simply copy and paste it to create your own animation of k-means. However, we also explain each of the functions below it in the explanation part.

Note: We sacrifice some modularity/extensibility and break certain code principles in this example to maximize its readability and minimize the amount of code required to reach the goal described in the title.

The K-Means Animation

Image

Code to visualize K-Means

from __future__ import annotations
import math
import sys
from io import BytesIO
from typing import List
import imageio
import matplotlib.pyplot as plt
import numpy as np
N_POINTS = 500 # Number of random points that are generated and clustered
N_CLUSTERS = 5 # Number of clusters, i.e. the "k" in k-means
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"] # Hex colours handed out to the cluster centres in this order
EPSILON = sys.float_info.epsilon # Machine epsilon of float; k_means stops once the centres move by less than this
class Point:
    """A point in the 2-dimensional plane together with its plot styling."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates plus the colour and marker size used for plotting."""
        self.x, self.y = x, y
        self.color = color  # colour the point is drawn in
        self.magnitude = magnitude  # marker size handed to pyplot

    def distance_to_point(self, point: "Point"):
        """Return the Euclidean distance between this point and ``point``."""
        squared_x = (self.x - point.x) ** 2
        squared_y = (self.y - point.y) ** 2
        return math.sqrt(squared_x + squared_y)
class PointList:
    """A collection of points with the aggregates needed for clustering and plotting."""

    def __init__(self, points: "List[Point] | None" = None, marker: str = "x"):
        """Wrap ``points`` and remember the pyplot ``marker``.

        ``points`` is stored as-is (not copied); when it is omitted a new
        empty list is created.
        """
        # Compare against None instead of relying on truthiness, so that an
        # empty list passed by the caller is kept rather than replaced.
        self.points = [] if points is None else points
        self.marker = marker  # The marker (symbol in the pyplot)

    def __len__(self):
        """Return the number of points so the built-in ``len()`` works too."""
        return len(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes (marker sizes) of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Scatter-plot the points via ``plt.scatter`` and return its result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the end of the list."""
        self.points.append(point)

    def len(self):
        """Return the number of points (kept for existing callers; see ``__len__``)."""
        return len(self.points)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The average x value; raises ``ZeroDivisionError`` for an empty list."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The average y value; raises ``ZeroDivisionError`` for an empty list."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: "PointList") -> float:
        """Return the combined distance between points at the same index of both lists.

        The result is the square root of the summed squared x and y
        differences. Points are paired with ``zip``, so surplus points of the
        longer list are ignored.
        """
        squared_differences = (
            (own_point.x - list_point.x) ** 2 + (own_point.y - list_point.y) ** 2
            for own_point, list_point in zip(self.points, other_points_list.points)
        )
        return math.sqrt(sum(squared_differences))
def random_point(**kwargs):
    """Create a point with uniformly random coordinates in the unit square [0, 1) x [0, 1).

    All keyword arguments are forwarded unchanged to ``Point``.
    """
    # rand() without arguments returns a plain float. rand(1) would return a
    # one-element array, and every coordinate derived from it would be an
    # array as well.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)
def random_points(n: int):
    """Return a PointList holding ``n`` freshly generated random points."""
    return PointList([random_point() for _ in range(n)])
def create_random_cluster_centres(k: int):
    """Return a PointList of ``k`` randomly placed cluster centres.

    Centres are drawn with the "o" marker and a magnitude of 150. Colours are
    taken from ``COLORS`` in order and start over when ``k`` exceeds the
    palette, so exactly ``k`` centres are always produced.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Wrap around the palette instead of stopping at its end.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres
def create_k_point_lists(k: int):
    """Return a list of ``k`` separate, empty PointLists."""
    return [PointList() for _ in range(k)]
def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its nearest centre and return one PointList per centre.

    The i-th returned list collects the points that are closest to the i-th
    centre; on a tie the centre that comes first wins. As a side effect each
    point takes over the colour of its centre (visualization only).
    """
    centre_points = centres.points
    buckets = create_k_point_lists(len(centre_points))
    for candidate in points.points:
        gaps = [candidate.distance_to_point(centre) for centre in centre_points]
        # Index of the smallest distance; min() keeps the first one on ties.
        nearest = min(range(len(gaps)), key=gaps.__getitem__)
        buckets[nearest].append(candidate)
        candidate.color = centre_points[nearest].color
    return buckets
def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList with one centre per cluster, in the order of ``clusters``.

    The centre of a non-empty cluster sits at the average x and y of its
    points and takes the colour of the cluster's first point. A cluster
    without points has neither an average nor a colour, so its centre is
    re-seeded at a random position instead of failing with
    ``ZeroDivisionError`` / ``IndexError``.
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if not cluster.points:
            # NOTE(review): assumes cluster i belongs to COLORS[i], which is how
            # create_random_cluster_centres hands out colours - confirm for
            # callers that supply their own centres.
            new_centres.append(
                random_point(color=COLORS[index % len(COLORS)], magnitude=150)
            )
            continue
        new_centres.append(
            Point(
                x=cluster.x_avg,  # average x value of all points in the cluster
                y=cluster.y_avg,  # average y value of all points in the cluster
                color=cluster.colors[0],  # the colour of the first point
                magnitude=150,  # centres are displayed larger than ordinary points
            )
        )
    return new_centres
def plot_styling():
    """Open a new figure with a dark background and white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)
    axes = plt.gca()
    axes.set_facecolor(background)
    for side in ("bottom", "top", "right", "left"):
        axes.spines[side].set_color("white")
    for direction in ("x", "y"):
        axes.tick_params(axis=direction, colors="white")
def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving and write ``k-means.gif``.

    Every iteration assigns the points to their nearest centre, recomputes
    the centres and renders one animation frame. Frames are appended to the
    module-level ``frames`` list, which must exist before this function is
    called (it may already hold an initial frame).

    Returns the final PointList of centres.
    """
    difference = 1  # any value >= EPSILON, so the loop body runs at least once
    n = 1  # iteration counter, shown in the frame title
    # difference is a square root and therefore never negative, so a plain
    # comparison against EPSILON is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)
        # Animation only
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # close the figure again to free its memory
        # / Animation only
        centres = new_centres
        n += 1
    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            # Writing left the stream position at the end of the data;
            # rewind so the image is read from its first byte.
            frame.seek(0)
            image = imageio.imread(frame)
            writer.append_data(image)
    return centres
if __name__ == '__main__':
    # Random data set and random starting centres.
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # k_means() appends its frames to this module-level list, so it has to
    # exist before the algorithm is started.
    frames = []

    # Render the untouched starting state as the first frame.
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Run the algorithm
    k_means(points, centres)

Explanation

The code above should run without any issues. If you encounter any though, feel free to comment at the bottom of the page. However, if you modified the code and run into issues, the following explanations might help.

Libraries

First of all we need some libraries. We use:

  • the __future__ package to be able to use type hints before their definition
  • the Python built-in module math for a few mathematical methods
  • the sys package to determine the epsilon of the float datatype
  • io to store byte objects in memory
  • the typing package for type hints
  • imageio to create and save the frames of our animation.
  • pyplot from matplotlib - as the name suggests - for plotting the data, or here, the pictures (which are nothing but matrices themselves). By convention matplotlib.pyplot is imported as plt
  • numpy to ease the array handling and outsource some operations to the way faster C-based libraries included. By convention, we import numpy as np to reduce the length of the code lines
from __future__ import annotations
import math
import sys
from io import BytesIO
from typing import List
import imageio
import matplotlib.pyplot as plt
import numpy as np
N_POINTS = 500 # Number of random points that are generated and clustered
N_CLUSTERS = 5 # Number of clusters, i.e. the "k" in k-means
COLORS = ["#3498db", "#2ecc71", "#f1c40f", "#9b59b6", "#e74c3c"] # Hex colours handed out to the cluster centres in this order
EPSILON = sys.float_info.epsilon # Machine epsilon of float; k_means stops once the centres move by less than this

We also define a few static variables, respectively:

  • N_POINTS for the number of dots we want to display in our k-means algorithm
  • N_CLUSTERS for the number of clusters (equals k)
  • COLORS for the colors of the clusters
  • EPSILON to calculate zero with the machine inaccuracy correctly

Point in a 2-dimensional Plane

We start the code by defining a class to represent a point in a 2-dimensional plane. It has a single and simple method to determine its Euclidean distance to another point.

Short recap of the Euclidean distance: the Euclidean distance between two points p and q is the square root of the sum of the squared differences of their coordinates; or as a formula:

$$d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$

In other words, as we only need 2 dimensions for our points, it is merely the square root of the sum of their squared x and y deltas, as shown in the code below.

class Point:
    """A point in the 2-dimensional plane together with its plot styling."""

    def __init__(self, x: float, y: float, color="grey", magnitude=20):
        """Store the coordinates plus the colour and marker size used for plotting."""
        self.x, self.y = x, y
        self.color = color  # colour the point is drawn in
        self.magnitude = magnitude  # marker size handed to pyplot

    def distance_to_point(self, point: "Point"):
        """Return the Euclidean distance between this point and ``point``."""
        squared_x = (self.x - point.x) ** 2
        squared_y = (self.y - point.y) ** 2
        return math.sqrt(squared_x + squared_y)

Additionally, we create an attribute color and magnitude to make it easier to visually distinguish the points.

Point List

To ease handling these points a bit, we create a data structure called PointList which is a mere interpretation of a list with a few methods specifically tailored to handling and plotting points.

class PointList:
    """A collection of points with the aggregates needed for clustering and plotting."""

    def __init__(self, points: "List[Point] | None" = None, marker: str = "x"):
        """Wrap ``points`` and remember the pyplot ``marker``.

        ``points`` is stored as-is (not copied); when it is omitted a new
        empty list is created.
        """
        # Compare against None instead of relying on truthiness, so that an
        # empty list passed by the caller is kept rather than replaced.
        self.points = [] if points is None else points
        self.marker = marker  # The marker (symbol in the pyplot)

    def __len__(self):
        """Return the number of points so the built-in ``len()`` works too."""
        return len(self.points)

    @property
    def x_values(self):
        """A list of the x values of all points in the list."""
        return [point.x for point in self.points]

    @property
    def y_values(self):
        """A list of the y values of all points in the list."""
        return [point.y for point in self.points]

    @property
    def colors(self):
        """A list of the colors of all points in the list."""
        return [point.color for point in self.points]

    @property
    def magnitudes(self):
        """A list of the magnitudes (marker sizes) of all points in the list."""
        return [point.magnitude for point in self.points]

    def plot(self):
        """Scatter-plot the points via ``plt.scatter`` and return its result."""
        return plt.scatter(
            x=self.x_values,
            y=self.y_values,
            color=self.colors,
            marker=self.marker,
            s=self.magnitudes,
        )

    def append(self, point):
        """Add a point to the end of the list."""
        self.points.append(point)

    def len(self):
        """Return the number of points (kept for existing callers; see ``__len__``)."""
        return len(self.points)

    @property
    def x_sum(self):
        """The sum of the x values of all points in the list."""
        return sum(self.x_values)

    @property
    def y_sum(self):
        """The sum of the y values of all points in the list."""
        return sum(self.y_values)

    @property
    def x_avg(self):
        """The average x value; raises ``ZeroDivisionError`` for an empty list."""
        return self.x_sum / self.len()

    @property
    def y_avg(self):
        """The average y value; raises ``ZeroDivisionError`` for an empty list."""
        return self.y_sum / self.len()

    def difference(self, other_points_list: "PointList") -> float:
        """Return the combined distance between points at the same index of both lists.

        The result is the square root of the summed squared x and y
        differences. Points are paired with ``zip``, so surplus points of the
        longer list are ignored.
        """
        squared_differences = (
            (own_point.x - list_point.x) ** 2 + (own_point.y - list_point.y) ** 2
            for own_point, list_point in zip(self.points, other_points_list.points)
        )
        return math.sqrt(sum(squared_differences))

We make use of the Python built-in decorator property for various properties of this list type. It is mere syntactic sugar that lets a method be read like an attribute, e.g. point_list.x_values instead of point_list.x_values().

Create Random PointList

To create some data we can cluster, we write two simple functions: the first creates a random Point in a 1x1 sized plane, the second creates a PointList that contains n of these random points.

def random_point(**kwargs):
    """Create a point with uniformly random coordinates in the unit square [0, 1) x [0, 1).

    All keyword arguments are forwarded unchanged to ``Point``.
    """
    # rand() without arguments returns a plain float. rand(1) would return a
    # one-element array, and every coordinate derived from it would be an
    # array as well.
    x = np.random.rand()
    y = np.random.rand()
    return Point(x, y, **kwargs)
def random_points(n: int):
    """Return a PointList holding ``n`` freshly generated random points."""
    return PointList([random_point() for _ in range(n)])

Create Random K-Clusters

The plain k-means algorithm is initialized by using randomly placed centres so we create a helper function that does exactly that.

def create_random_cluster_centres(k: int):
    """Return a PointList of ``k`` randomly placed cluster centres.

    Centres are drawn with the "o" marker and a magnitude of 150. Colours are
    taken from ``COLORS`` in order and start over when ``k`` exceeds the
    palette, so exactly ``k`` centres are always produced.
    """
    centres = PointList(marker="o")
    for index in range(k):
        # Wrap around the palette instead of stopping at its end.
        color = COLORS[index % len(COLORS)]
        centres.append(random_point(color=color, magnitude=150))
    return centres

Besides being represented by another marker, there is basically no difference between a point and a cluster centre, which is why we reuse creating a random point.

Create K Point Lists

Another helper function we're going to use is create_k_point_lists. It returns a list of k empty PointLists - one per cluster - to ease handling them.

def create_k_point_lists(k: int):
    """Return a list of ``k`` separate, empty PointLists."""
    return [PointList() for _ in range(k)]

Cluster Points

The first real submethod of k-means is the cluster_points method. It receives a PointList of the points and a PointList of the centres and assigns each point to a new cluster list, based on the centre it is closest to.

def cluster_points(points: PointList, centres: PointList) -> List[PointList]:
    """Assign every point to its nearest centre and return one PointList per centre.

    The i-th returned list collects the points that are closest to the i-th
    centre; on a tie the centre that comes first wins. As a side effect each
    point takes over the colour of its centre (visualization only).
    """
    centre_points = centres.points
    buckets = create_k_point_lists(len(centre_points))
    for candidate in points.points:
        gaps = [candidate.distance_to_point(centre) for centre in centre_points]
        # Index of the smallest distance; min() keeps the first one on ties.
        nearest = min(range(len(gaps)), key=gaps.__getitem__)
        buckets[nearest].append(candidate)
        candidate.color = centre_points[nearest].color
    return buckets

Calculate New Centres

Another submethod of k-means is to calculate_new_centres based on the average x and y coordinates of all the points in the cluster.

def calculate_new_centres(clusters: List[PointList]):
    """Return a PointList with one centre per cluster, in the order of ``clusters``.

    The centre of a non-empty cluster sits at the average x and y of its
    points and takes the colour of the cluster's first point. A cluster
    without points has neither an average nor a colour, so its centre is
    re-seeded at a random position instead of failing with
    ``ZeroDivisionError`` / ``IndexError``.
    """
    new_centres = PointList(marker="o")
    for index, cluster in enumerate(clusters):
        if not cluster.points:
            # NOTE(review): assumes cluster i belongs to COLORS[i], which is how
            # create_random_cluster_centres hands out colours - confirm for
            # callers that supply their own centres.
            new_centres.append(
                random_point(color=COLORS[index % len(COLORS)], magnitude=150)
            )
            continue
        new_centres.append(
            Point(
                x=cluster.x_avg,  # average x value of all points in the cluster
                y=cluster.y_avg,  # average y value of all points in the cluster
                color=cluster.colors[0],  # the colour of the first point
                magnitude=150,  # centres are displayed larger than ordinary points
            )
        )
    return new_centres

Plot Styling

Not much to say about this code segment. It simply applies some styling settings to the plot to match the general styling of this page.

def plot_styling():
    """Open a new figure with a dark background and white spines and ticks."""
    background = "#111827"
    plt.figure(facecolor=background)
    axes = plt.gca()
    axes.set_facecolor(background)
    for side in ("bottom", "top", "right", "left"):
        axes.spines[side].set_color("white")
    for direction in ("x", "y"):
        axes.tick_params(axis=direction, colors="white")

K-Means Algorithm

Finally we put all the code together and run the k_means algorithm. If the file is run directly, the if __name__ == '__main__': condition is met and the code listed below it is executed.

def k_means(points: PointList, centres: PointList):
    """Run k-means on ``points``, starting from ``centres``, until the centres stop moving.

    Every iteration assigns the points to their nearest centre (which also
    recolours them) and recomputes the centres from the resulting clusters.

    Returns the final PointList of centres.
    """
    difference = 1  # any value >= EPSILON, so the loop body runs at least once
    # difference is a square root and therefore never negative, so a plain
    # comparison against EPSILON is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        new_centres = calculate_new_centres(new_clusters)
        # How far the centres moved compared to the previous iteration.
        difference = new_centres.difference(centres)
        centres = new_centres
    return centres
if __name__ == '__main__':
    # Random data set and random starting centres.
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # Cluster the data set, starting from those centres.
    k_means(points, centres)

We create a PointList of N_POINTS random points to cluster and initialize k-means with N_CLUSTERS random points as the centres. When k-means is run, there are 3 steps that are repeated until the centres of each cluster do not move anymore (in relation to the previous iteration):

  1. Cluster the points by assigning them to the closest centre
  2. Calculate the new centres by getting the average position of all points in each cluster
  3. Calculate the distance the centres moved in relation to the previous iteration to determine whether the termination condition is met

Create the animation

To create an animation of these, we take a kind of snapshot of each iteration at specific points in the algorithm. This modifies the last part of the code slightly:

def k_means(points: PointList, centres: PointList):
    """Run k-means until the centres stop moving and write ``k-means.gif``.

    Every iteration assigns the points to their nearest centre, recomputes
    the centres and renders one animation frame. Frames are appended to the
    module-level ``frames`` list, which must exist before this function is
    called (it may already hold an initial frame).

    Returns the final PointList of centres.
    """
    difference = 1  # any value >= EPSILON, so the loop body runs at least once
    n = 1  # iteration counter, shown in the frame title
    # difference is a square root and therefore never negative, so a plain
    # comparison against EPSILON is sufficient.
    while difference >= EPSILON:
        new_clusters = cluster_points(points, centres)
        new_centres = calculate_new_centres(new_clusters)
        difference = new_centres.difference(centres)
        # Animation only
        plot_styling()
        points.plot()
        centres.plot()
        plt.title(f'Iteration {n}', color="white")
        frame_bytes = BytesIO()
        plt.savefig(frame_bytes)
        frames.append(frame_bytes)
        plt.close("all")  # close the figure again to free its memory
        # / Animation only
        centres = new_centres
        n += 1
    with imageio.get_writer('k-means.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            # Writing left the stream position at the end of the data;
            # rewind so the image is read from its first byte.
            frame.seek(0)
            image = imageio.imread(frame)
            writer.append_data(image)
    return centres
if __name__ == '__main__':
    # Random data set and random starting centres.
    points = random_points(N_POINTS)
    centres = create_random_cluster_centres(k=N_CLUSTERS)

    # k_means() appends its frames to this module-level list, so it has to
    # exist before the algorithm is started.
    frames = []

    # Render the untouched starting state as the first frame.
    plot_styling()
    points.plot()
    centres.plot()
    plt.title("Initialization", color="white")
    initial_frame = BytesIO()
    plt.savefig(initial_frame)
    frames.append(initial_frame)

    # Run the algorithm
    k_means(points, centres)