Last active
April 4, 2019 07:10
-
-
Save knok/9f0429d3954017a289c1b5ef76661657 to your computer and use it in GitHub Desktop.
ChainerCVを用いて猫画像を分類する ref: https://qiita.com/knok/items/8b1919e2a8b71d9134c9
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/nogpu/bin/python | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import chainer | |
| from chainercv.datasets import voc_detection_label_names | |
| from chainercv.links import SSD300 | |
| from chainercv import utils | |
| import os | |
| def main(): | |
| chainer.config.train = False | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--gpu', type=int, default=-1) | |
| parser.add_argument('--pretrained_model', default='voc0712') | |
| parser.add_argument('src_dir') | |
| parser.add_argument('dst_dir') | |
| args = parser.parse_args() | |
| model = SSD300( | |
| n_fg_class=len(voc_detection_label_names), | |
| pretrained_model=args.pretrained_model) | |
| if args.gpu >= 0: | |
| model.to_gpu(args.gpu) | |
| chainer.cuda.get_device(args.gpu).use() | |
| file_lists = [] | |
| for f in os.listdir(args.src_dir): | |
| if not f.startswith("."): | |
| file_lists.append(f) | |
| if not os.path.exists(args.dst_dir): | |
| os.mkdir(args.dst_dir) | |
| cat_id = voc_detection_label_names.index('cat') # 猫に相当するindexの取得 | |
| def has_cat(labels): | |
| for l in labels: | |
| if type(l) == int: | |
| if l == cat_id: | |
| return True | |
| for ll in l: # labelsが配列の配列を返すことがある | |
| if ll == cat_id: | |
| return True | |
| return False | |
| print("target file: %d files" % len(file_lists)) | |
| count = 0 | |
| for f in file_lists: | |
| fname = os.path.join(args.src_dir, f) | |
| img = utils.read_image(fname, color=True) | |
| bboxes, labels, scores = model.predict([img]) | |
| if has_cat(labels): | |
| dst_fname = os.path.join(args.dst_dir, f) | |
| os.rename(fname, dst_fname) | |
| count += 1 | |
| print("%d: move from %s to %s" % (count, fname, dst_fname)) | |
| print("%d files moved." % count) | |
| if __name__ == '__main__': | |
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/nogpu/bin/python | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import chainer | |
| from chainercv.datasets import voc_detection_label_names | |
| from chainercv.links import SSD300 | |
| from chainercv import utils | |
| import os | |
| def main(): | |
| chainer.config.train = False | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--gpu', type=int, default=-1) | |
| parser.add_argument('--pretrained_model', default='voc0712') | |
| parser.add_argument('src_dir') | |
| parser.add_argument('dst_dir') | |
| args = parser.parse_args() | |
| model = SSD300( | |
| n_fg_class=len(voc_detection_label_names), | |
| pretrained_model=args.pretrained_model) | |
| if args.gpu >= 0: | |
| model.to_gpu(args.gpu) | |
| chainer.cuda.get_device(args.gpu).use() | |
| file_lists = [] | |
| for f in os.listdir(args.src_dir): | |
| if not f.startswith("."): | |
| file_lists.append(f) | |
| if not os.path.exists(args.dst_dir): | |
| os.mkdir(args.dst_dir) | |
| cat_id = voc_detection_label_names.index('cat') # 猫に相当するindexの取得 | |
| def has_cat(labels): | |
| for l in labels: | |
| if type(l) == int: | |
| if l == cat_id: | |
| return True | |
| for ll in l: # labelsが配列の配列を返すことがある | |
| if ll == cat_id: | |
| return True | |
| return False | |
| print("target file: %d files" % len(file_lists)) | |
| count = 0 | |
| for f in file_lists: | |
| fname = os.path.join(args.src_dir, f) | |
| img = utils.read_image(fname, color=True) | |
| bboxes, labels, scores = model.predict([img]) | |
| if has_cat(labels): | |
| dst_fname = os.path.join(args.dst_dir, f) | |
| os.rename(fname, dst_fname) | |
| count += 1 | |
| print("%d: move from %s to %s" % (count, fname, dst_fname)) | |
| print("%d files moved." % count) | |
| if __name__ == '__main__': | |
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| $ python | |
| >>> from chainercv.datasets import voc_detection_label_names | |
| >>> voc_detection_label_names | |
| ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| $ python | |
| >>> from chainercv.datasets import voc_detection_label_names | |
| >>> voc_detection_label_names | |
| ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment