Skip to content

Instantly share code, notes, and snippets.

@knok
Last active April 4, 2019 07:10
Show Gist options
  • Select an option

  • Save knok/9f0429d3954017a289c1b5ef76661657 to your computer and use it in GitHub Desktop.

Select an option

Save knok/9f0429d3954017a289c1b5ef76661657 to your computer and use it in GitHub Desktop.
ChainerCVを用いて猫画像を分類する ref: https://qiita.com/knok/items/8b1919e2a8b71d9134c9
#!/usr/nogpu/bin/python
# -*- coding: utf-8 -*-
import argparse
import chainer
from chainercv.datasets import voc_detection_label_names
from chainercv.links import SSD300
from chainercv import utils
import os
def contains_label(labels, label_id):
    """Return True if ``label_id`` occurs anywhere in ``labels``.

    ``labels`` may be a scalar label (Python int, NumPy integer, 0-d array),
    a flat sequence/array of labels, or a sequence of such sequences -- e.g.
    the per-image list of label arrays returned by ``model.predict([...])``.
    """
    try:
        items = iter(labels)
    except TypeError:
        # Not iterable, so this is a single scalar label.
        return bool(labels == label_id)
    return any(contains_label(item, label_id) for item in items)


def main():
    """Move every image in ``src_dir`` in which SSD300 detects a cat to ``dst_dir``.

    Command-line arguments:
      --gpu               GPU device id; a negative value (default -1) uses the CPU.
      --pretrained_model  Weights name/path passed to ``SSD300`` (default 'voc0712').
      src_dir             Directory scanned (non-recursively) for images.
      dst_dir             Destination directory; created if it does not exist.
    """
    import shutil  # shutil.move also works when src and dst are on different filesystems

    chainer.config.train = False
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--pretrained_model', default='voc0712')
    parser.add_argument('src_dir')
    parser.add_argument('dst_dir')
    args = parser.parse_args()

    model = SSD300(
        n_fg_class=len(voc_detection_label_names),
        pretrained_model=args.pretrained_model)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)
        chainer.cuda.get_device(args.gpu).use()

    # Skip hidden entries and anything that is not a regular file
    # (sub-directories would make read_image fail); sort for a stable order.
    file_lists = [
        f for f in sorted(os.listdir(args.src_dir))
        if not f.startswith(".")
        and os.path.isfile(os.path.join(args.src_dir, f))]
    if not os.path.isdir(args.dst_dir):
        os.makedirs(args.dst_dir)

    cat_id = voc_detection_label_names.index('cat')  # class index of 'cat'

    print("target file: %d files" % len(file_lists))
    count = 0
    for f in file_lists:
        fname = os.path.join(args.src_dir, f)
        img = utils.read_image(fname, color=True)
        # predict takes a list of images and returns per-image lists.
        _, labels, _ = model.predict([img])
        if contains_label(labels, cat_id):
            dst_fname = os.path.join(args.dst_dir, f)
            shutil.move(fname, dst_fname)
            count += 1
            print("%d: move from %s to %s" % (count, fname, dst_fname))
    print("%d files moved." % count)


if __name__ == '__main__':
    main()
#!/usr/nogpu/bin/python
# -*- coding: utf-8 -*-
import argparse
import chainer
from chainercv.datasets import voc_detection_label_names
from chainercv.links import SSD300
from chainercv import utils
import os
def contains_label(labels, label_id):
    """Return True if ``label_id`` occurs anywhere in ``labels``.

    ``labels`` may be a scalar label (Python int, NumPy integer, 0-d array),
    a flat sequence/array of labels, or a sequence of such sequences -- e.g.
    the per-image list of label arrays returned by ``model.predict([...])``.
    """
    try:
        items = iter(labels)
    except TypeError:
        # Not iterable, so this is a single scalar label.
        return bool(labels == label_id)
    return any(contains_label(item, label_id) for item in items)


def main():
    """Move every image in ``src_dir`` in which SSD300 detects a cat to ``dst_dir``.

    Command-line arguments:
      --gpu               GPU device id; a negative value (default -1) uses the CPU.
      --pretrained_model  Weights name/path passed to ``SSD300`` (default 'voc0712').
      src_dir             Directory scanned (non-recursively) for images.
      dst_dir             Destination directory; created if it does not exist.
    """
    import shutil  # shutil.move also works when src and dst are on different filesystems

    chainer.config.train = False
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--pretrained_model', default='voc0712')
    parser.add_argument('src_dir')
    parser.add_argument('dst_dir')
    args = parser.parse_args()

    model = SSD300(
        n_fg_class=len(voc_detection_label_names),
        pretrained_model=args.pretrained_model)
    if args.gpu >= 0:
        model.to_gpu(args.gpu)
        chainer.cuda.get_device(args.gpu).use()

    # Skip hidden entries and anything that is not a regular file
    # (sub-directories would make read_image fail); sort for a stable order.
    file_lists = [
        f for f in sorted(os.listdir(args.src_dir))
        if not f.startswith(".")
        and os.path.isfile(os.path.join(args.src_dir, f))]
    if not os.path.isdir(args.dst_dir):
        os.makedirs(args.dst_dir)

    cat_id = voc_detection_label_names.index('cat')  # class index of 'cat'

    print("target file: %d files" % len(file_lists))
    count = 0
    for f in file_lists:
        fname = os.path.join(args.src_dir, f)
        img = utils.read_image(fname, color=True)
        # predict takes a list of images and returns per-image lists.
        _, labels, _ = model.predict([img])
        if contains_label(labels, cat_id):
            dst_fname = os.path.join(args.dst_dir, f)
            shutil.move(fname, dst_fname)
            count += 1
            print("%d: move from %s to %s" % (count, fname, dst_fname))
    print("%d files moved." % count)


if __name__ == '__main__':
    main()
$ python
>>> from chainercv.datasets import voc_detection_label_names
>>> voc_detection_label_names
('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
$ python
>>> from chainercv.datasets import voc_detection_label_names
>>> voc_detection_label_names
('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment