Last active
June 5, 2017 07:26
-
-
Save jakeoung/79b72bdeed8b08d97462ca6886991623 to your computer and use it in GitHub Desktop.
jkplot library and Python argparse usage: encoding parameters into result file names and loading results back by parameter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# main.py
"""Example driver: encode the key hyper-parameters into the result file name."""
import argparse
import os
import time

import numpy as np


def build_key_parser():
    """Return the parser for the *key* parameters that are encoded in the result file name."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--niter', type=int, default=100, help='number of iteration')
    return parser


def make_result_name(args, prefix=None):
    """
    Build '<prefix>--key_value_--key2_value2_...' from every attribute of args.

    Keys are emitted in sorted order. prefix defaults to today's date as 'mmdd_'.
    JKPlot later splits this name on '_' to recover the parameters, so values
    must not contain '_' themselves.
    """
    if prefix is None:
        prefix = time.strftime('%m%d_')
    return prefix + ''.join('--%s_%s_' % (key, value)
                            for key, value in sorted(vars(args).items()))


def main(argv=None, results=None):
    """
    Parse the command line, build the result path and optionally save results.

    argv    : argument list; None means sys.argv[1:] (argparse default)
    results : object handed to np.save; nothing is written when it is None
              or when --dResult is not an existing directory

    Returns (args, path). path is --dResult joined with the generated file
    name; np.save appends '.npy' to it.
    """
    key_parser = build_key_parser()
    # Only the key parameters are known here; other flags are left unparsed so
    # that they do not end up in the file name.
    args, _unparsed = key_parser.parse_known_args(argv)
    fResult = make_result_name(args)

    # Full parser: key parameters plus options that must NOT appear in the name.
    parser = argparse.ArgumentParser(parents=[key_parser], conflict_handler='resolve')
    parser.add_argument('--dResult', type=str, default='', help='directory path to save the result')
    parser.add_argument('-v', '--verbose', action='count', help='enable verbose mode')
    args = parser.parse_args(argv, namespace=args)

    # dResult is only available after the full parse, so build the path here.
    fResult_ = os.path.join(args.dResult, fResult)
    print(args, fResult)

    if results is not None and os.path.isdir(args.dResult):
        np.save(fResult_, results)
    return args, fResult_


if __name__ == '__main__':
    # TODO: pass the experiment's output arrays (e.g. [array1, array2]) as results=
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| from matplotlib import rc | |
| import pandas as pd | |
| import re | |
| import argparse | |
| import random | |
def set_matplotlib():
    """
    Install compact global matplotlib defaults.

    Axis labels and the base font are 8 pt. Legend and tick labels are 10 pt.
    LaTeX text rendering is disabled and the default figure size is 6 x 4 inches.
    """
    # Optional font overrides, kept for reference:
    # rc('font', **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans'], 'size': 10})
    # rc('mathtext', **{'default': 'regular'})
    settings = {}
    settings.update(dict.fromkeys(('axes.labelsize', 'font.size'), 8))
    settings.update(dict.fromkeys(('legend.fontsize', 'xtick.labelsize', 'ytick.labelsize'), 10))
    settings['text.usetex'] = False
    settings['figure.figsize'] = [6.0, 4.0]  # instead of 4.5, 4.5
    plt.rcParams.update(settings)
class JKPlot(object):
    """
    Plot utility class for general purpose based on argparse.

    Result files are expected to be named like '--key_value_--key2_value2_...'
    (the scheme main.py uses). Splitting the base name on '_' gives a token list
    that parser.parse_known_args turns back into a namespace. Tokens the parser
    does not know, such as a date prefix or the extension, are ignored.
    Values that themselves contain '_' cannot be recovered this way.
    """

    def __init__(self, dResult, parser, ext='npy', fList=None):
        """
        dResult : list (or tuple) of directories containing result files
        parser  : argparse.ArgumentParser that knows the encoded parameters
        ext     : only file names ending with this string are collected
        fList   : explicit list of result files; when given, dResult is not scanned

        Raises TypeError if dResult is not a list or tuple.
        """
        if not isinstance(dResult, (list, tuple)):
            # A bare string would otherwise be scanned character by character.
            raise TypeError('dResult must be a list of directories, got %s'
                            % type(dResult).__name__)
        self.dResult = dResult
        self.parser = parser
        if fList is None:
            # os.path.join works whether or not the directory has a trailing '/'.
            self.fList = [os.path.join(res, f)
                          for res in dResult
                          for f in sorted(os.listdir(res)) if f.endswith(ext)]
        else:
            # Copy so that remove_file does not mutate the caller's list.
            self.fList = list(fList)

    def _parse(self, f):
        """Return the argparse namespace decoded from the base name of file f."""
        params = os.path.basename(f).split('_')
        args, _unknown = self.parser.parse_known_args(params)
        return args

    @staticmethod
    def _matches(args, criteria):
        """Return True if every key of criteria equals that attribute of args (KeyError on unknown keys)."""
        dic = vars(args)
        return all(dic[key] == value for key, value in criteria.items())

    def load_file(self, f):
        """
        load file f, depending on the saved format
        """
        res = np.load(f)
        # ...
        return res

    def remove_file(self, **kwargs):
        """
        Delete the first result file whose parameters match kwargs.

        The file is removed from disk and from self.fList, so later lookups do
        not try to load it. Returns the removed path, or None if nothing matched.
        """
        for f in self.fList:
            if self._matches(self._parse(f), kwargs):
                os.remove(f)
                # Safe while iterating: we return immediately afterwards.
                self.fList.remove(f)
                return f
        return None

    def load_result(self, **kwargs):
        """
        Load the first result whose parameters match kwargs.

        Returns (result, args), or None when no file matches.
        Usage:
        >>> res, args = jp.load_result(opt='sgd', lr=0.05, batchSize=32, dropout=0)
        """
        for f in self.fList:
            args = self._parse(f)
            if self._matches(args, kwargs):
                print(args)
                return self.load_file(f), args
        return None

    def generate(self, **kwargs):
        """
        Yield (result, args) for every file whose parameters match kwargs.
        """
        for f in self.fList:
            args = self._parse(f)
            if self._matches(args, kwargs):
                yield self.load_file(f), args

    def varying_params(self, varying, **fixed):
        """
        Yield results in which only the given parameters vary.

        A reference file matching `fixed` is chosen at random. Then every file
        that agrees with the reference on all parameters not listed in
        `varying` is yielded, the reference file included.

        Args
            varying (str or list of str) : parameter name(s) allowed to vary
            fixed (keyword arguments)    : parameters the reference must have
        Returns
            Yields (result, args). Yields nothing when no file matches `fixed`.
        Example
            >>> for res, args in jp.varying_params(varying='batchProb', opt='rsgd', model='resnet18'):
            >>>     print(args)
        """
        # A bare string would otherwise be matched as a substring ('lr' in 'lrDecay').
        if isinstance(varying, str):
            varying = [varying]
        varying = set(varying)

        parsed = [(f, self._parse(f)) for f in self.fList]
        candidates = [args for _, args in parsed if self._matches(args, fixed)]
        if not candidates:
            # Nothing matches `fixed`; the old rejection-sampling loop never ended here.
            return
        reference = vars(random.choice(candidates))
        pinned = {key: value for key, value in reference.items() if key not in varying}

        for f, args in parsed:
            if self._matches(args, pinned):
                yield self.load_file(f), args
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: decode a parameter-encoded result file name back into argparse values.
import argparse

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--batchSize', type=int, default=32, help='mini batch size')
parser.add_argument('--nEpoch', type=int, default=10, help='number of epochs to train for')
# %(default)s keeps the help text in sync with the real default (it used to claim 0.0002).
parser.add_argument('--lr', type=float, default=0.01, help='learning rate, default=%(default)s')
parser.add_argument('--batchProb', type=int, default=0, help='batch size to compare (prob=batchSize/batchProb)')
parser.add_argument('--useBN', type=int, default=0, help='use batch normalization')
parser.add_argument('--useDropout', type=int, default=0, help='use dropout')
parser.add_argument('--opt', type=str, default='sgd', help='optimizers: sgd, rsgd (ours) (default:sgd)')

fName = '--lr_0.01_--batchSize_32_--batchProb_0_--useBN_0_--opt_sgd_--useDropout_0_--nEpoch_10_.txt'
tokens = fName.split('_')
print(tokens)
# Tokens the parser does not recognise (here the trailing '.txt') are returned separately.
decoded, unknown = parser.parse_known_args(tokens)
print(decoded, unknown)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment