Skip to content

Instantly share code, notes, and snippets.

@dobrosketchkun
Last active September 14, 2022 09:20
Show Gist options
  • Select an option

  • Save dobrosketchkun/08ce3c04382ea31f81d497f55436e754 to your computer and use it in GitHub Desktop.

Select an option

Save dobrosketchkun/08ce3c04382ea31f81d497f55436e754 to your computer and use it in GitHub Desktop.
Get a palette from an image
'''
Shamelessly ripped with some meat from a number of unknown sources, sorry.
'''
from math import sqrt
import pandas as pd
import numpy as np
import random
try:
import Image
except ImportError:
from PIL import Image
class Point:
    """An n-dimensional point that simply holds its coordinate sequence."""

    def __init__(self, coordinates):
        # Stored exactly as given: no copy and no validation.
        self.coordinates = coordinates
class Cluster:
    """A group of points paired with the point acting as its center."""

    def __init__(self, center, points):
        # Both values are stored as given (no copy is made).
        self.points = points
        self.center = center
class KMeans:
    """Minimal k-means clustering over ``Point`` objects.

    Args:
        n_clusters: Number of clusters to produce.
        min_diff: Convergence threshold. Iteration stops after a pass in
            which every cluster center moved by less than this distance.
    """

    def __init__(self, n_clusters, min_diff=1):
        self.n_clusters = n_clusters
        self.min_diff = min_diff

    def calculate_center(self, points):
        """Return a ``Point`` at the per-dimension mean of *points*.

        *points* must be non-empty. The dimensionality is taken from the
        first point; every other point must have at least that many
        coordinates.
        """
        n_dim = len(points[0].coordinates)
        totals = [0.0] * n_dim
        for p in points:
            for i in range(n_dim):
                totals[i] += p.coordinates[i]
        count = len(points)
        return Point([total / count for total in totals])

    def assign_points(self, clusters, points):
        """Group *points* by the index of their nearest cluster center.

        Returns:
            A list of ``n_clusters`` lists; entry ``i`` holds the points
            whose closest center is ``clusters[i].center``. Ties go to the
            lowest cluster index.
        """
        plists = [[] for _ in range(self.n_clusters)]
        for p in points:
            # Start from cluster 0 so ``idx`` is always bound, even if no
            # distance compares strictly below infinity.
            idx = 0
            smallest_distance = float('inf')
            for i in range(self.n_clusters):
                distance = euclidean(p, clusters[i].center)
                if distance < smallest_distance:
                    smallest_distance = distance
                    idx = i
            plists[idx].append(p)
        return plists

    def fit(self, points):
        """Cluster *points* and return the resulting list of ``Cluster``s.

        Initial centers are ``n_clusters`` points drawn with
        ``random.sample``, so results can differ between runs. Points are
        then repeatedly reassigned to their nearest center and the centers
        recomputed, until no center moves by ``min_diff`` or more.

        Raises:
            ValueError: Propagated from ``random.sample`` when *points* has
                fewer than ``n_clusters`` elements.
        """
        clusters = [
            Cluster(center=p, points=[p])
            for p in random.sample(points, self.n_clusters)
        ]
        while True:
            plists = self.assign_points(clusters, points)
            diff = 0
            for i, members in enumerate(plists):
                old = clusters[i]
                if not members:
                    # Nothing was assigned this pass: keep the center but
                    # drop the previous membership, which now belongs to
                    # other clusters. Otherwise ``len(cluster.points)``
                    # would report stale counts.
                    clusters[i] = Cluster(old.center, [])
                    continue
                new = Cluster(self.calculate_center(members), members)
                clusters[i] = new
                diff = max(diff, euclidean(old.center, new.center))
            if diff < self.min_diff:
                break
        return clusters
def euclidean(p, q):
    """Return the Euclidean distance between points *p* and *q*.

    The dimensionality is taken from ``p.coordinates``; *q* must provide at
    least that many coordinates, and any extra ones are ignored.
    """
    squared_total = 0
    for axis, value in enumerate(p.coordinates):
        delta = value - q.coordinates[axis]
        squared_total += delta ** 2
    return sqrt(squared_total)
def get_points(image_path):
    """Load an image and return one ``Point`` per pixel of its thumbnail.

    The image is shrunk in place with ``thumbnail((200, 400))`` and
    converted to RGB. Every pixel of the result becomes its own ``Point``
    whose coordinates are that pixel's colour tuple.

    Args:
        image_path: Anything accepted by ``Image.open``.

    Returns:
        A list of ``Point`` objects, grouped by colour in the order
        reported by ``getcolors``.
    """
    # The context manager releases the underlying file even when decoding
    # or conversion raises.
    with Image.open(image_path) as img:
        img.thumbnail((200, 400))
        rgb = img.convert("RGB")
        w, h = rgb.size
        # An image has at most w * h distinct colours, so this maxcolors
        # limit is never exceeded (getcolors returns None when it is).
        color_counts = rgb.getcolors(w * h)
    points = []
    for count, color in color_counts:
        # One Point per pixel, so frequent colours weigh more in clustering.
        points.extend(Point(color) for _ in range(count))
    return points
def rgb_to_hex(rgb):
    """Format an iterable of integer channel values as a ``#``-prefixed hex string.

    Each value is rendered as lowercase hex, zero-padded to two digits.
    """
    hex_pairs = ['%02x' % channel for channel in rgb]
    return '#' + ''.join(hex_pairs)
def get_colors(filename, n_colors=3):
    """Extract a palette of *n_colors* colours from an image via k-means.

    Args:
        filename: Image path, passed to ``get_points``.
        n_colors: Number of clusters (palette entries) to compute.

    Returns:
        A ``(hex_codes, rgbs)`` tuple: a list of hex strings and a list of
        integer channel lists. Both are ordered from the cluster with the
        most points to the one with the fewest.
    """
    pixels = get_points(filename)
    model = KMeans(n_clusters=n_colors)
    ranked = sorted(
        model.fit(pixels),
        key=lambda cluster: len(cluster.points),
        reverse=True,
    )
    rgbs = []
    for cluster in ranked:
        # int() truncates the averaged (float) center coordinates.
        rgbs.append([int(value) for value in cluster.center.coordinates])
    hex_codes = [rgb_to_hex(rgb) for rgb in rgbs]
    return hex_codes, rgbs
#######################
# Number of palette colours to extract; display_pal() also reads this
# module-level value to size and fill the swatch image.
number_of_colors = 5
#######################
# Runs when the module is executed or imported: computes the palette once.
# `colors` holds the hex strings; `rgbs` holds the matching integer channel
# lists, which display_pal() reads.
colors, rgbs = get_colors('/content/165742854.jpg', n_colors=number_of_colors)
def display_pal(
    w=50, h=50, save_to_file=False, filename="color_palette", extension="jpg"
):
    """Render the palette as a horizontal row of swatches, show it, and return it.

    Reads the module-level ``number_of_colors`` and ``rgbs``.

    Args:
        w: Width of each swatch in pixels.
        h: Height of each swatch (and of the whole image) in pixels.
        save_to_file: When true, also save the image to
            ``f"{filename}.{extension}"``.
        filename: Output file name without extension.
        extension: Output file extension, placed after the dot.

    Returns:
        The rendered ``Image``.
    """
    img = Image.new("RGB", size=(w * number_of_colors, h))
    arr = np.asarray(img).copy()
    for i in range(number_of_colors):
        # Axis 1 of the array is the horizontal axis, so each swatch spans
        # `w` columns. Slicing by `h` here was wrong whenever w != h.
        arr[:, i * w : (i + 1) * w, :] = rgbs[i]
    img = Image.fromarray(arr, "RGB")
    img.show()
    if save_to_file:
        img.save(f"{filename}.{extension}")
    return img
# Print the palette as hex strings, then render and show the swatches using
# display_pal()'s default sizes (nothing is saved to disk).
print(colors)
display_pal()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment