Created
October 31, 2025 13:44
-
-
Save izdrail/7e24c9f2935f401e25512dbc562aefa7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Split between train and val folders | |
| from pathlib import Path | |
| import random | |
| import os | |
| import sys | |
| import shutil | |
| import argparse | |
| # Define and parse user input arguments | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--datapath', help='Path to data folder containing image and annotation files', | |
| required=True) | |
| parser.add_argument('--train_pct', help='Ratio of images to go to train folder; \ | |
| the rest go to validation folder (example: ".8")', | |
| default=.8) | |
| args = parser.parse_args() | |
| data_path = args.datapath | |
| train_percent = float(args.train_pct) | |
| # Check for valid entries | |
| if not os.path.isdir(data_path): | |
| print('Directory specified by --datapath not found. Verify the path is correct (and uses double back slashes if on Windows) and try again.') | |
| sys.exit(0) | |
| if train_percent < .01 or train_percent > 0.99: | |
| print('Invalid entry for train_pct. Please enter a number between .01 and .99.') | |
| sys.exit(0) | |
| val_percent = 1 - train_percent | |
| # Define path to input dataset | |
| input_image_path = os.path.join(data_path,'images') | |
| input_label_path = os.path.join(data_path,'labels') | |
| # Define paths to image and annotation folders | |
| cwd = os.getcwd() | |
| train_img_path = os.path.join(cwd,'data/train/images') | |
| train_txt_path = os.path.join(cwd,'data/train/labels') | |
| val_img_path = os.path.join(cwd,'data/validation/images') | |
| val_txt_path = os.path.join(cwd,'data/validation/labels') | |
| # Create folders if they don't already exist | |
| for dir_path in [train_img_path, train_txt_path, val_img_path, val_txt_path]: | |
| if not os.path.exists(dir_path): | |
| os.makedirs(dir_path) | |
| print(f'Created folder at {dir_path}.') | |
| # Get list of all images and annotation files | |
| img_file_list = [path for path in Path(input_image_path).rglob('*')] | |
| txt_file_list = [path for path in Path(input_label_path).rglob('*')] | |
| print(f'Number of image files: {len(img_file_list)}') | |
| print(f'Number of annotation files: {len(txt_file_list)}') | |
| # Determine number of files to move to each folder | |
| file_num = len(img_file_list) | |
| train_num = int(file_num*train_percent) | |
| val_num = file_num - train_num | |
| print('Images moving to train: %d' % train_num) | |
| print('Images moving to validation: %d' % val_num) | |
| # Select files randomly and copy them to train or val folders | |
| for i, set_num in enumerate([train_num, val_num]): | |
| for ii in range(set_num): | |
| img_path = random.choice(img_file_list) | |
| img_fn = img_path.name | |
| base_fn = img_path.stem | |
| txt_fn = base_fn + '.txt' | |
| txt_path = os.path.join(input_label_path,txt_fn) | |
| if i == 0: # Copy first set of files to train folders | |
| new_img_path, new_txt_path = train_img_path, train_txt_path | |
| elif i == 1: # Copy second set of files to the validation folders | |
| new_img_path, new_txt_path = val_img_path, val_txt_path | |
| shutil.copy(img_path, os.path.join(new_img_path,img_fn)) | |
| #os.rename(img_path, os.path.join(new_img_path,img_fn)) | |
| if os.path.exists(txt_path): # If txt path does not exist, this is a background image, so skip txt file | |
| shutil.copy(txt_path,os.path.join(new_txt_path,txt_fn)) | |
| #os.rename(txt_path,os.path.join(new_txt_path,txt_fn)) | |
| img_file_list.remove(img_path) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment