Skip to content

Instantly share code, notes, and snippets.

@jakeoung
Last active June 5, 2017 07:26
Show Gist options
  • Select an option

  • Save jakeoung/79b72bdeed8b08d97462ca6886991623 to your computer and use it in GitHub Desktop.

Select an option

Save jakeoung/79b72bdeed8b08d97462ca6886991623 to your computer and use it in GitHub Desktop.
jkplot library, plus an example of using Python argparse to encode parameters into result file names and load them back
########## main.py
import argparse
import os
import time

import numpy as np

# Options declared on this first parser are the "key" parameters: they are
# encoded into the result file name so that JKPlot can later recover them by
# splitting the name on '_' and re-parsing it with the same parser.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--niter', type=int, default=100, help='number of iteration')
args, unparsed = parser.parse_known_args()

# Output file name: an MMDD_ date prefix followed by "--key_value_" for each
# key parameter in sorted order, e.g. "0605_--niter_100_".
dic = vars(args)
fResult = time.strftime('%m%d_') + ''.join(
    '--' + str(key) + '_' + str(dic[key]) + '_' for key in sorted(dic))

# Second parser: inherits the key parameters and adds options that should
# NOT be encoded into the file name.
parser = argparse.ArgumentParser(parents=[parser], conflict_handler='resolve')
parser.add_argument('--dResult', type=str, default='', help='directory path to save the result')
parser.add_argument('-v', '--verbose', action='count', help='enable verbose mode')
# Further non-key options can be added here, e.g.
#   parser.add_argument('--useOurs', action='store_true', help='use ours')
args = parser.parse_args(namespace=args)

# The full output path can only be built after --dResult has been parsed.
fResult_ = os.path.join(args.dResult, fResult)
print(args, fResult)

## write result
if args.verbose:
    pass
if os.path.exists(args.dResult):
    # NOTE(review): array1 / array2 are placeholders for the experiment's
    # outputs; they must be defined by the experiment code before this point.
    np.save(fResult_, [array1, array2])
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib import rc
import pandas as pd
import re
import argparse
import random
def set_matplotlib():
    """Apply compact global matplotlib style settings for figures.

    Updates ``plt.rcParams`` in place:
    - 8 pt axis labels and base font;
    - 10 pt legend and tick labels;
    - LaTeX text rendering disabled;
    - a default figure size of 6 x 4 inches.
    """
    # Optional alternatives (disabled): a sans-serif global font via
    #   rc('font', **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans'], 'size': 10})
    # and regular-weight matplotlib mathtext via
    #   rc('mathtext', **{'default': 'regular'})
    style = dict()
    style['axes.labelsize'] = 8
    style['font.size'] = 8
    for tick_or_legend in ('legend.fontsize', 'xtick.labelsize', 'ytick.labelsize'):
        style[tick_or_legend] = 10
    style['text.usetex'] = False
    style['figure.figsize'] = [6., 4.]  # previously 4.5 x 4.5
    plt.rcParams.update(style)
class JKPlot(object):
    """
    Plot utility class for general purpose based on argparse.

    Result files are expected to encode argparse options in their base name,
    separated by underscores, e.g. ``0605_--batchSize_32_--lr_0.01_.npy``.
    Splitting such a name on ``'_'`` and passing the pieces to
    ``parser.parse_known_args`` recovers the parameters the result was
    produced with. Tokens the parser does not know (date prefix, extension)
    are ignored.
    """

    def __init__(self, dResult, parser, ext='npy', fList=None):
        """
        dResult : list of directories containing result files (a single
                  directory string is also accepted)
        parser  : argparse.ArgumentParser declaring the encoded options
        ext     : only file names ending with this suffix are collected
        fList   : explicit list of result file paths; when given, the
                  directories in dResult are not scanned

        Raises TypeError if dResult is neither a string nor a list/tuple.
        """
        if isinstance(dResult, str):
            dResult = [dResult]
        if not isinstance(dResult, (list, tuple)):
            raise TypeError('dResult must be a list of directories, got %s'
                            % type(dResult).__name__)
        self.dResult = list(dResult)
        self.parser = parser
        if fList is None:
            # os.path.join works whether or not a directory ends with a
            # separator; sorting makes the file order deterministic.
            self.fList = [os.path.join(d, name)
                          for d in self.dResult
                          for name in sorted(os.listdir(d))
                          if name.endswith(ext)]
        else:
            # Copy so that remove_file does not mutate the caller's list.
            self.fList = list(fList)

    def _parse_name(self, f):
        """Recover the argparse namespace encoded in the base name of path f."""
        args, _unknown = self.parser.parse_known_args(os.path.basename(f).split('_'))
        return args

    @staticmethod
    def _matches(dic, wanted):
        """Return True if dic[key] == wanted[key] for every key in wanted.

        Raises KeyError if wanted names an option the parser does not define.
        """
        return all(dic[key] == value for key, value in wanted.items())

    def _iter_matching(self, wanted):
        """Yield (path, args) for every file in fList whose params match wanted."""
        for f in self.fList:
            args = self._parse_name(f)
            if self._matches(vars(args), wanted):
                yield f, args

    def load_file(self, f):
        """
        load file f, depending on the saved format
        """
        res = np.load(f)
        # ...
        return res

    def remove_file(self, **kwargs):
        """
        Delete the first file whose encoded parameters match kwargs.

        The path is also dropped from fList so later lookups do not try to
        load a file that no longer exists. Does nothing if no file matches.
        """
        for f, _args in self._iter_matching(kwargs):
            os.remove(f)
            # Safe: we stop iterating immediately after mutating the list.
            self.fList.remove(f)
            return

    def load_result(self, **kwargs):
        """
        Load the first result whose encoded parameters match kwargs.

        Returns (result, args), or None if no file matches.
        Usage:
        >>> res, args = jp.load_result(opt='sgd', lr=0.05, batchSize=32, dropout=0)
        """
        for f, args in self._iter_matching(kwargs):
            print(args)
            return self.load_file(f), args
        return None

    def generate(self, **kwargs):
        """
        Yield (result, args) for every file whose parameters match kwargs.
        """
        for f, args in self._iter_matching(kwargs):
            yield self.load_file(f), args

    def varying_params(self, varying, **fixed):
        """
        Generate results while varying the given parameters.

        A random reference file matching `fixed` is chosen first. Every
        parameter that is not listed in `varying` is then pinned to the
        reference's value.

        Args
            varying (str or list) : name(s) of the parameters allowed to vary
            fixed (dictionary)    : parameters that must take these values
        Returns
            Yields (result, args) for each file that agrees with the
            reference on all non-varying parameters. Yields nothing if no
            file matches `fixed`.
        Example
        >>> for res, args in jp.varying_params(varying='batchProb', opt='rsgd', model='resnet18'):
        >>>     print(args)
        """
        # A bare string would otherwise be treated as a container of
        # substrings by the `in` test below.
        if isinstance(varying, str):
            varying = [varying]

        # Pick a uniformly random reference among the files matching `fixed`.
        # Iterating a shuffled copy terminates even when nothing matches.
        candidates = list(self.fList)
        random.shuffle(candidates)
        reference = None
        for f in candidates:
            dic = vars(self._parse_name(f))
            if self._matches(dic, fixed):
                reference = dic
                break
        if reference is None:
            return

        # Pin every non-varying parameter to the reference's value.
        pinned = {key: value for key, value in reference.items() if key not in varying}
        for f, args in self._iter_matching(pinned):
            yield self.load_file(f), args
import argparse

# Example: declare the options encoded in result file names and check that
# such a name parses back into the original values.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--batchSize', type=int, default=32, help='mini batch size')
parser.add_argument('--nEpoch', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate, default=0.01')
parser.add_argument('--batchProb', type=int, default=0, help='batch size to compare (prob=batchSize/batchProb)')
parser.add_argument('--useBN', type=int, default=0, help='use batch normalization')
parser.add_argument('--useDropout', type=int, default=0, help='use dropout')
parser.add_argument('--opt', type=str, default='sgd', help='optimizers: sgd, rsgd (ours) (default:sgd)')
fName = '--lr_0.01_--batchSize_32_--batchProb_0_--useBN_0_--opt_sgd_--useDropout_0_--nEpoch_10_.txt'

if __name__ == '__main__':
    tokens = fName.split('_')
    print(tokens)
    # The trailing '.txt' token is returned in the "unknown" list.
    print(parser.parse_known_args(tokens))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment