Last active
September 14, 2022 09:20
-
-
Save dobrosketchkun/08ce3c04382ea31f81d497f55436e754 to your computer and use it in GitHub Desktop.
Get a palette from an image
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Shamelessly ripped with some meat from a number of unknown sources, sorry. | |
| ''' | |
| from math import sqrt | |
| import pandas as pd | |
| import numpy as np | |
| import random | |
| try: | |
| import Image | |
| except ImportError: | |
| from PIL import Image | |
class Point:
    """A point in n-dimensional space.

    In this script a point is either a pixel color or a cluster mean.
    ``coordinates`` is stored as given (not copied) and must be indexable.
    """

    def __init__(self, coordinates):
        self.coordinates = coordinates

    def __repr__(self):
        # Makes points readable when printed or inspected in a debugger.
        return f"{type(self).__name__}({self.coordinates!r})"
class Cluster:
    """A cluster: its center plus the points currently assigned to it."""

    def __init__(self, center, points):
        self.center = center
        self.points = points

    def __repr__(self):
        # Report only the member count; clusters can hold tens of thousands of points.
        return (
            f"{type(self).__name__}(center={self.center!r}, "
            f"n_points={len(self.points)})"
        )
class KMeans:
    """Plain-Python k-means clustering over ``Point`` objects.

    Parameters
    ----------
    n_clusters : int
        Number of clusters to produce.
    min_diff : float, default 1
        Convergence threshold: iteration stops once the largest movement of
        any cluster center in a single pass is below ``min_diff``.
    """

    def __init__(self, n_clusters, min_diff=1):
        self.n_clusters = n_clusters
        self.min_diff = min_diff

    def calculate_center(self, points):
        """Return a new ``Point`` at the component-wise mean of *points*.

        *points* must be non-empty. The dimensionality is taken from the
        first point.
        """
        n_dim = len(points[0].coordinates)
        totals = [0.0] * n_dim
        for p in points:
            for i in range(n_dim):
                totals[i] += p.coordinates[i]
        return Point([t / len(points) for t in totals])

    def assign_points(self, clusters, points):
        """Group *points* by nearest cluster center.

        Returns ``self.n_clusters`` lists; list ``i`` holds the points whose
        nearest center (Euclidean) is ``clusters[i].center``. Ties go to the
        lowest cluster index.
        """
        plists = [[] for _ in range(self.n_clusters)]
        for p in points:
            distances = [
                euclidean(p, clusters[i].center) for i in range(self.n_clusters)
            ]
            # index(min(...)) picks the first minimum, so ties go to the
            # lowest index. It is also computed fresh for every point.
            plists[distances.index(min(distances))].append(p)
        return plists

    def fit(self, points):
        """Cluster *points* and return the final list of ``Cluster`` objects.

        Initial centers are ``n_clusters`` points drawn without replacement
        via ``random.sample``, so results depend on the ``random`` module's
        state. Assignment and center updates repeat until the largest center
        movement in a pass is below ``self.min_diff``.

        A cluster that receives no points keeps its previous center and gets
        an empty ``points`` list, so ``len(cluster.points)`` always reflects
        the final assignment.

        Raises
        ------
        ValueError
            If *points* has fewer than ``n_clusters`` elements (raised by
            ``random.sample``).
        """
        clusters = [
            Cluster(center=p, points=[p])
            for p in random.sample(points, self.n_clusters)
        ]
        while True:
            plists = self.assign_points(clusters, points)
            diff = 0
            for i, members in enumerate(plists):
                if not members:
                    # Keep the center but drop stale membership (the seed
                    # point or an earlier pass's members). Callers rank
                    # clusters by len(points).
                    clusters[i] = Cluster(clusters[i].center, [])
                    continue
                old_center = clusters[i].center
                clusters[i] = Cluster(self.calculate_center(members), members)
                diff = max(diff, euclidean(old_center, clusters[i].center))
            if diff < self.min_diff:
                break
        return clusters
def euclidean(p, q):
    """Return the straight-line (L2) distance between points *p* and *q*.

    The dimensionality is taken from *p*. Each of its coordinates is paired
    with the coordinate at the same index in *q*.
    """
    pc, qc = p.coordinates, q.coordinates
    squared_terms = ((pc[k] - qc[k]) ** 2 for k in range(len(pc)))
    return sqrt(sum(squared_terms))
def get_points(image_path, max_size=(200, 400)):
    """Load an image and return one ``Point`` per pixel of a downscaled copy.

    The image is shrunk with ``thumbnail`` to fit within *max_size*. This
    preserves the aspect ratio and never enlarges. It is then converted to
    RGB. Each distinct color is repeated once per pixel that has it, so
    k-means weighs colors by frequency.

    Parameters
    ----------
    image_path :
        Anything ``Image.open`` accepts (e.g. a file path).
    max_size : tuple[int, int], default (200, 400)
        Bounding box ``(width, height)`` for the thumbnail. It bounds how
        many points, and hence how much k-means work, are produced.

    Returns
    -------
    list[Point]
        Points whose ``coordinates`` are ``(r, g, b)`` tuples.
    """
    # Context manager ensures the file handle opened by Image.open is closed.
    with Image.open(image_path) as img:
        img.thumbnail(max_size)
        # convert() returns a new image, so it stays usable after img closes.
        rgb_img = img.convert("RGB")
    w, h = rgb_img.size
    # maxcolors = pixel count guarantees getcolors() does not return None.
    return [
        Point(color)
        for count, color in rgb_img.getcolors(w * h)
        for _ in range(count)
    ]
def rgb_to_hex(rgb):
    """Format a sequence of integer channels as a lowercase ``#rrggbb`` string.

    Each channel is rendered with ``'%02x'`` (at least two hex digits). The
    pieces are concatenated after a leading ``'#'``.
    """
    hex_pairs = ['%02x' % channel for channel in rgb]
    return '#' + ''.join(hex_pairs)
def get_colors(filename, n_colors=3):
    """Extract a dominant-color palette from an image.

    Runs k-means on the image's pixels. Cluster centers are returned in
    order from the largest cluster (most pixels) to the smallest.

    Parameters
    ----------
    filename :
        Image source, passed through to ``get_points``.
    n_colors : int, default 3
        Number of palette entries (k-means clusters).

    Returns
    -------
    tuple[list[str], list[list[int]]]
        ``(hex_codes, rgbs)``: ``'#rrggbb'`` strings and the matching
        ``[r, g, b]`` integer lists, in the same order.
    """
    points = get_points(filename)
    clusters = KMeans(n_clusters=n_colors).fit(points)
    clusters.sort(key=lambda c: len(c.points), reverse=True)
    # Round (rather than truncate) the float cluster means to the nearest
    # channel value. Truncation would bias every channel downward.
    rgbs = [[int(round(v)) for v in c.center.coordinates] for c in clusters]
    return [rgb_to_hex(rgb) for rgb in rgbs], rgbs
# Number of palette entries (k-means clusters) to extract.
number_of_colors = 5

# Extract the palette from the sample image. `colors` holds the hex strings
# and `rgbs` the matching [r, g, b] lists, largest cluster first.
colors, rgbs = get_colors(filename='/content/165742854.jpg', n_colors=number_of_colors)
def display_pal(
    w=50, h=50, save_to_file=False, filename="color_palette", extension="jpg",
    palette=None,
):
    """Render a palette as a horizontal strip of solid swatches and show it.

    Parameters
    ----------
    w, h : int, default 50
        Width and height of each swatch in pixels.
    save_to_file : bool, default False
        Also save the strip as ``f"{filename}.{extension}"``.
    filename, extension : str
        Output name parts, used only when ``save_to_file`` is true.
    palette : sequence of RGB triples, optional
        Colors to draw, left to right. Defaults to the module-level ``rgbs``,
        so the zero-argument call keeps working.

    Returns
    -------
        The rendered image, ``w * len(palette)`` pixels wide and ``h`` tall.
    """
    if palette is None:
        palette = rgbs
    # Pixel buffer laid out as (rows=h, columns=w * n, RGB), initially black.
    arr = np.zeros((h, w * len(palette), 3), dtype=np.uint8)
    for i, color in enumerate(palette):
        # Each swatch spans w columns (not h), so non-square swatches render
        # correctly.
        arr[:, i * w:(i + 1) * w, :] = color
    img = Image.fromarray(arr, "RGB")
    img.show()
    if save_to_file:
        img.save(f"{filename}.{extension}")
    return img
# Print the hex codes, then show the palette strip without saving it to disk.
print(colors)
display_pal(save_to_file=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment