Created
June 9, 2014 22:33
-
-
Save Lambdanaut/bc7dfc4a217a7f323e35 to your computer and use it in GitHub Desktop.
Integer Vector Quantization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ Classifies a point as belonging to a cluster of points """ | |
| from math import sqrt | |
| from random import shuffle | |
# Size of the character grid drawn by show_points, in cells.
WIDTH = 28
HEIGHT = 28

# Sample query point handed to print_vq by the demo call at the end of the file.
POINT_TO_CLASSIFY = (2, 4)

# Cells (0, 0), (1, 1), ... along the diagonal, one per column of the grid.
DIAGONAL_LINE = list(zip(range(WIDTH), range(WIDTH)))

# Sample data set: four hand-placed groups of (x, y) points.
DISTRIBUTION = [
    # CLUSTER
    (2, 4), (3, 2), (5, 5), (6, 4),
    # CLUSTER
    (10, 10), (10, 12), (14, 10), (15, 8), (16, 10),
    # CLUSTER
    (20, 2), (20, 3), (20, 4), (23, 3), (25, 2),
    # CLUSTER
    (26, 26), (26, 27), (27, 22), (27, 23), (27, 27),
]
def classify(point, centroids):
    """Return the entry of ``centroids`` that lies nearest to ``point``.

    Nearness is measured with ``distance``. When several centroids are
    equally near, the one that appears first in ``centroids`` is returned.
    Indexing ``centroids[0]`` raises ``IndexError`` for an empty list.
    """
    nearest = centroids[0]
    nearest_dist = distance(point, nearest)

    # Lazily pair every later centroid with its distance to the point.
    scored = ((candidate, distance(point, candidate))
              for candidate in centroids[1:])
    for candidate, candidate_dist in scored:
        # Strict comparison: an equally distant later centroid never wins.
        if candidate_dist < nearest_dist:
            nearest, nearest_dist = candidate, candidate_dist
    return nearest
def vq(points, outputs=4):
    """Quantize ``points`` down to ``outputs`` centroids.

    A shuffled copy of ``points`` is taken and its last ``outputs`` entries
    become the seed centroids. Each remaining point is then visited once, in
    shuffled order: the centroid nearest to it (by ``distance``, first one
    wins ties) is replaced by the rounded midpoint of that centroid and the
    point.

    Args:
        points: sequence of ``(x, y)`` pairs. It is left unmodified.
        outputs: number of centroids to produce.

    Returns:
        A list of ``outputs`` centroids. A seed that never absorbs a point is
        returned as it was; every other centroid is a ``vec_round`` result.

    Raises:
        IndexError: if ``points`` has fewer than ``outputs`` entries, or if
            ``outputs`` is 0 while ``points`` is not empty.
    """
    # Work on a private copy: shuffling and popping the argument itself
    # would reorder the caller's list and remove the seed points from it.
    remaining = list(points)
    shuffle(remaining)

    # Seed centroids
    centroids = [remaining.pop() for _ in range(outputs)]

    for point in remaining:
        nearest_index = 0
        nearest_dist = distance(point, centroids[0])
        for index in range(1, len(centroids)):
            candidate_dist = distance(point, centroids[index])
            if candidate_dist < nearest_dist:
                nearest_dist = candidate_dist
                nearest_index = index
        # Pull the winning centroid halfway towards the point, then snap it
        # back onto the integer grid.
        centroids[nearest_index] = vec_round(
            center(centroids[nearest_index], point))
    return centroids
def center(p1, p2):
    """Return the midpoint of the segment from ``p1`` to ``p2``.

    Each coordinate sum is multiplied by 0.5, so integer inputs give float
    results. The sum is deliberately not passed through ``abs``: taking the
    absolute value would mirror the midpoint onto the positive side whenever
    a coordinate sum is negative.
    """
    x1, y1 = p1
    x2, y2 = p2
    return ((x2 + x1) * 0.5, (y2 + y1) * 0.5)
def add(p1, p2):
    """Return the component-wise sum of two ``(x, y)`` pairs."""
    (ax, ay), (bx, by) = p1, p2
    total_x = bx + ax
    total_y = by + ay
    return (total_x, total_y)
def vec_abs(p):
    """Return ``p`` with both coordinates replaced by their absolute value."""
    px, py = p
    return tuple(map(abs, (px, py)))
def vec_round(p):
    """Return ``p`` with each coordinate passed through the built-in ``round``."""
    col, row = p
    rounded_col = round(col)
    rounded_row = round(row)
    return (rounded_col, rounded_row)
def scale(point, scalar):
    """Return ``point`` with both coordinates multiplied by ``scalar``."""
    px, py = point
    return tuple(coord * scalar for coord in (px, py))
def distance(p1, p2):
    """Return the Euclidean distance between two ``(x, y)`` pairs."""
    (ax, ay), (bx, by) = p1, p2
    dx = bx - ax
    dy = by - ay
    return sqrt(dx ** 2 + dy ** 2)
def show_points(points, width=None, height=None):
    """Render ``points`` as a text grid, one ``[x]`` or ``[ ]`` per cell.

    Rows are emitted top to bottom for ``y`` in ``0 .. height - 1`` and each
    row runs left to right for ``x`` in ``0 .. width - 1``. A cell is drawn
    as ``[x]`` when ``(x, y)`` is one of ``points`` and as ``[ ]`` otherwise;
    points outside the grid are ignored. Every row ends with a newline.

    Args:
        points: iterable of hashable ``(x, y)`` pairs.
        width: number of columns; defaults to the module constant ``WIDTH``.
        height: number of rows; defaults to the module constant ``HEIGHT``.

    Returns:
        The whole grid as a single string.
    """
    cols = WIDTH if width is None else width
    rows = HEIGHT if height is None else height

    # Build the lookup set once so each cell test does not rescan the list.
    occupied = set(points)

    lines = []
    for y in range(rows):
        cells = ['[x]' if (x, y) in occupied else '[ ]' for x in range(cols)]
        lines.append(''.join(cells) + '\n')
    return ''.join(lines)
def print_vq(points, classify_point=None):
    """Print ``points``, their quantization and, optionally, a classification.

    Output, in order: the grid drawn by ``show_points`` for the input, a
    mapping from cluster index to centroid as returned by ``vq``, the grid
    for those centroids, and - when ``classify_point`` is truthy - the index
    of the centroid that ``classify`` picks for it.

    Every ``print`` call takes exactly one parenthesised argument, so the
    function produces the same output on Python 2 and Python 3.

    Args:
        points: list of ``(x, y)`` pairs to quantize.
        classify_point: optional ``(x, y)`` pair to assign to a centroid.
    """
    print('Original Data:')
    print(show_points(points))

    print('Quantized Data:')
    # Pass a copy so the caller's list is not reordered or shortened.
    quantized = vq(list(points))
    print(dict(enumerate(quantized)))
    print(show_points(quantized))

    if classify_point:
        classified = classify(classify_point, quantized)
        # Format the message first, then print it: calling .format on the
        # result of print() fails on Python 3, where print returns None.
        print('Classified {} as part of cluster {}'.format(
            classify_point, quantized.index(classified)))
# Demo run: quantize the sample distribution and classify the sample point.
# This executes whenever the module is run or imported.
print_vq(points=DISTRIBUTION, classify_point=POINT_TO_CLASSIFY)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment