Created April 1, 2018 11:27
-
-
Save remorsecs/00098a0e597f3f5ace32703ae298c6eb to your computer and use it in GitHub Desktop.
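Just for fun: nearest-neighbor face classification on the Cropped Yale dataset, comparing SSD and SAD distances.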
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image | |
| from tqdm import tqdm | |
def transforms(image_path, size=(192, 168)):
    """Load an image as a flattened grayscale float row vector.

    Args:
        image_path: path (``str`` or ``Path``) to an image file PIL can open.
        size: target ``(height, width)`` in pixels; defaults to ``(192, 168)``.

    Returns:
        numpy array of shape ``(1, height * width)`` holding float pixel
        values of the grayscale ('L' mode) image.
    """
    height, width = size
    # Use a context manager so the underlying file handle is closed;
    # convert() loads the pixels into a new, file-independent image.
    with Image.open(str(image_path)) as img:
        gray = img.convert('L')
    # np.resize does not resample: it truncates/repeats the flattened pixel
    # buffer, scrambling any image whose shape differs. Resample with PIL
    # instead (PIL sizes are (width, height)); skip when already correct.
    if gray.size != (width, height):
        gray = gray.resize((width, height))
    image = np.array(gray, dtype=float)
    return np.reshape(image, (1, -1))
class CroppedYaleDataset:
    """Per-class train/test split of a directory tree of ``.pgm`` images.

    ``root`` is expected to contain one sub-directory per class. Every
    ``*.pgm`` file in a class directory is loaded with ``transforms``.
    Within each class, the first ``num_train`` images (in sorted filename
    order) go to ``train_set`` and the remainder go to ``test_set``.

    Attributes:
        root: dataset root as a ``Path``.
        classes: sorted list of class directories; integer label ``i``
            corresponds to ``classes[i]``.
        train_set: list of ``(image_row_vector, label)`` tuples.
        test_set: list of ``(image_row_vector, label)`` tuples.
    """

    def __init__(self, root, num_train=35):
        self.root = Path(root)
        # Path.glob yields entries in arbitrary filesystem order. Sort so the
        # label assignment and the train/test split are reproducible, and
        # skip stray files so they do not consume a label index.
        self.classes = sorted(p for p in self.root.glob('*') if p.is_dir())
        self.train_set = []
        self.test_set = []
        for label, class_dir in enumerate(self.classes):
            for j, image in enumerate(sorted(class_dir.glob('*.pgm'))):
                target = self.train_set if j < num_train else self.test_set
                target.append((transforms(image), label))
def ssd(image1, image2):
    """Return the sum of squared differences between two arrays.

    Both arguments are subtracted elementwise (numpy broadcasting rules
    apply) and every squared residual is summed into a single scalar.
    """
    residual = image1 - image2
    return np.sum(residual * residual)
def sad(image1, image2):
    """Return the sum of absolute differences between two arrays.

    Both arguments are subtracted elementwise (numpy broadcasting rules
    apply) and the magnitudes of the residuals are summed into a scalar.
    """
    delta = image1 - image2
    return np.absolute(delta).sum()
def knn(dataset: CroppedYaleDataset, metric, k=1):
    """Classify each test image by k-nearest-neighbour vote and report accuracy.

    For every ``(image, label)`` in ``dataset.test_set``, ``metric`` is called
    as ``metric(train_image, test_image)`` against every training image. The
    ``k`` closest training samples vote on the predicted label. Ties in
    distance keep training-set order, and ties in the vote go to the label
    whose member is nearest. With the default ``k=1`` this is plain
    nearest-neighbour: the first training sample with the minimum distance
    wins.

    Args:
        dataset: object exposing ``train_set`` and ``test_set`` lists of
            ``(image, label)`` tuples.
        metric: callable returning a comparable distance for two images;
            its ``__name__`` is used in the printed report.
        k: number of neighbours that vote (default 1).

    Returns:
        The fraction of test images classified correctly (also printed).

    Raises:
        ValueError: if ``k < 1`` or either the train or test set is empty.
    """
    import heapq
    from collections import Counter

    if k < 1:
        raise ValueError(f'k must be >= 1, got {k}')
    # An empty train set would otherwise silently predict a default label,
    # and an empty test set would divide by zero below.
    if not dataset.train_set:
        raise ValueError('training set is empty')
    if not dataset.test_set:
        raise ValueError('test set is empty')

    corrects = 0
    for test_image, test_label in tqdm(dataset.test_set):
        scored = ((metric(train_image, test_image), train_label)
                  for train_image, train_label in dataset.train_set)
        # nsmallest is stable (equivalent to sorted(...)[:k]); for k == 1
        # it reduces to min(), which keeps the first minimum.
        nearest = heapq.nsmallest(k, scored, key=lambda pair: pair[0])
        # Counter keeps insertion (nearest-first) order, so most_common's
        # stable ordering breaks vote ties toward the nearer label.
        votes = Counter(label for _, label in nearest)
        pred_label = votes.most_common(1)[0][0]
        if pred_label == test_label:
            corrects += 1
    accuracy = corrects / len(dataset.test_set)
    print(f'{metric.__name__} accuracy: {accuracy:.4f}')
    return accuracy
if __name__ == '__main__':
    # Build the split once, then score nearest-neighbour classification
    # with each distance function in turn (SSD first, then SAD).
    DATASET_PATH = 'CroppedYale'
    dataset = CroppedYaleDataset(DATASET_PATH)
    for distance_fn in (ssd, sad):
        knn(dataset, distance_fn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.