Last active
August 17, 2020 14:01
-
-
Save zoezhou1999/bf3460f61d0bccbe0dcbe3072f5672a9 to your computer and use it in GitHub Desktop.
matting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| #import matplotlib.pyplot as plt | |
| import pdb | |
| from torch.nn.modules.loss import _Loss | |
| from torch.autograd import Function, Variable | |
| #import scipy.io as sio | |
def compute_gradient(img):
    """Return forward-difference spatial gradients of an NCHW image batch.

    Args:
        img: tensor of shape (N, C, H, W).
    Returns:
        gradx: difference along the height axis (dim 2), shape (N, C, H-1, W).
        grady: difference along the width axis (dim 3), shape (N, C, H, W-1).

    Fix: the original diffed dims 1 and 2 — a verbatim port of TF NHWC
    indexing — so the "x" gradient ran across *channels*. Every tensor in
    this file is NCHW (see the F.avg_pool2d / F.conv2d usage), so the
    spatial gradients belong on dims 2 and 3.
    """
    gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
    grady = img[:, :, :, 1:] - img[:, :, :, :-1]
    return gradx, grady
| #multi-task loss | |
class MultiLoss(_Loss):
    """Learned two-task loss weighting (uncertainty-style):
    exp(-s_A)*lossA + exp(-s_B)*lossB + s_A + s_B, with s_A, s_B trainable scalars."""
    def __init__(self):
        super(MultiLoss, self).__init__()
        # Trainable log-weight scalars, created directly on the GPU.
        # NOTE(review): plain tensors, not nn.Parameters — they are optimised
        # only if added to an optimizer explicitly; confirm callers do that.
        self.sigma_A = torch.zeros(1).float().cuda().requires_grad_()
        self.sigma_B = torch.zeros(1).float().cuda().requires_grad_()
    def forward(self, lossA, lossB):
        weighted = torch.exp(-self.sigma_A) * lossA + torch.exp(-self.sigma_B) * lossB
        # The additive sigmas act as a regulariser preventing the weights
        # from collapsing to zero.
        return weighted + self.sigma_A + self.sigma_B
class CrossEntropyLoss(_Loss):
    """Thin wrapper around nn.CrossEntropyLoss (used for trimap classification).

    Improvement: the criterion is instantiated once in __init__ instead of
    being rebuilt on every forward call.
    """
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss()
    def forward(self, trimap_pred, trimap_gt):
        # trimap_pred: raw logits; trimap_gt: class-index targets.
        return self.ce(trimap_pred, trimap_gt)
class ExclusionLoss(_Loss):
    """Gradient exclusion loss: penalises correlated gradients between two
    layers across a small image pyramid, so the network separates them."""
    def __init__(self):
        super(ExclusionLoss, self).__init__()
    def forward(self, img1, img2, level=3):
        eps = 1e-8
        gradx_loss, grady_loss = [], []
        for _ in range(level):
            gradx1, grady1 = compute_gradient(img1)
            gradx2, grady2 = compute_gradient(img2)
            # Scale img2's gradients so both layers contribute comparably.
            alphax = 2.0 * torch.mean(torch.abs(gradx1)) / (torch.mean(torch.abs(gradx2)) + eps)
            alphay = 2.0 * torch.mean(torch.abs(grady1)) / (torch.mean(torch.abs(grady2)) + eps)
            # Squash gradients into (-1, 1) with a sigmoid.
            gradx1_s = (torch.sigmoid(gradx1) * 2) - 1
            grady1_s = (torch.sigmoid(grady1) * 2) - 1
            gradx2_s = (torch.sigmoid(gradx2 * alphax) * 2) - 1
            grady2_s = (torch.sigmoid(grady2 * alphay) * 2) - 1
            # Per-sample correlation of squared gradients, 4th-root damped.
            gradx_loss.append((torch.mean((gradx1_s * gradx1_s) * (gradx2_s * gradx2_s), dim=(1, 2, 3)) + eps) ** 0.25)
            grady_loss.append((torch.mean((grady1_s * grady1_s) * (grady2_s * grady2_s), dim=(1, 2, 3)) + eps) ** 0.25)
            # Move one pyramid level down.
            img1 = F.avg_pool2d(img1, (2, 2), (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2), (2, 2))
        return torch.sum(sum(gradx_loss) / (1.0 * level)) + torch.sum(sum(grady_loss) / (1.0 * level))
class alpha_loss(_Loss):
    """L1 loss on the alpha matte, normalised by the mask area."""
    def __init__(self):
        super(alpha_loss, self).__init__()
    def forward(self, alpha, alpha_pred, mask):
        # Delegate to the shared mask-normalised L1 helper.
        return normalized_l1_loss(alpha, alpha_pred, mask)
class compose_loss(_Loss):
    """Compositing loss: mask-normalised L1 between the real image and the
    image recomposited from fg/bg with the predicted alpha."""
    def __init__(self):
        super(compose_loss, self).__init__()
    def forward(self, image, alpha_pred, fg, bg, mask):
        # Standard compositing equation: I = alpha*F + (1 - alpha)*B.
        composite = fg * alpha_pred + (1 - alpha_pred) * bg
        return normalized_l1_loss(image, composite, mask)
class alpha_gradient_loss(_Loss):
    """Mask-normalised L1 between Sobel gradients of GT and predicted alpha."""
    def __init__(self):
        super(alpha_gradient_loss, self).__init__()
    def forward(self, alpha, alpha_pred, mask):
        # 3x3 Sobel kernels, shaped (out=1, in=1, 3, 3) for single-channel conv2d.
        sobel_x = Variable(torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3).cuda())
        sobel_y = Variable(torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view(1, 1, 3, 3).cuda())
        G_x = F.conv2d(alpha, sobel_x, padding=1)
        G_y = F.conv2d(alpha, sobel_y, padding=1)
        G_x_pred = F.conv2d(alpha_pred, sobel_x, padding=1)
        G_y_pred = F.conv2d(alpha_pred, sobel_y, padding=1)
        return normalized_l1_loss(G_x, G_x_pred, mask) + normalized_l1_loss(G_y, G_y_pred, mask)
| ###################################### | |
| ###TO DO | |
class detailed_alpha_gradient_loss(_Loss):
    """Sobel-gradient loss with an extra direction term: the magnitude L1 of
    alpha_gradient_loss plus an L1 between the cross terms G_y*G_x_pred and
    G_x*G_y_pred (cf. the TF form |gt_gy*pd_gx - gt_gx*pd_gy|)."""
    def __init__(self):
        super(detailed_alpha_gradient_loss, self).__init__()
    def forward(self, alpha, alpha_pred, mask):
        # 3x3 Sobel kernels, shaped (out=1, in=1, 3, 3) for single-channel conv2d.
        sobel_x = Variable(torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3).cuda())
        sobel_y = Variable(torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view(1, 1, 3, 3).cuda())
        G_x = F.conv2d(alpha, sobel_x, padding=1)
        G_y = F.conv2d(alpha, sobel_y, padding=1)
        G_x_pred = F.conv2d(alpha_pred, sobel_x, padding=1)
        G_y_pred = F.conv2d(alpha_pred, sobel_y, padding=1)
        # Gradient-magnitude agreement.
        loss_grad_mag = normalized_l1_loss(G_x, G_x_pred, mask) + normalized_l1_loss(G_y, G_y_pred, mask)
        # Gradient-direction agreement (zero when gradients are parallel).
        loss_grad_dir = normalized_l1_loss(torch.mul(G_y, G_x_pred), torch.mul(G_x, G_y_pred), mask)
        return loss_grad_mag + loss_grad_dir
class alpha_gradient_reg_loss(_Loss):
    """Gradient regulariser: total absolute Sobel gradient of alpha divided by
    the mask area (encourages smooth mattes)."""
    def __init__(self):
        super(alpha_gradient_reg_loss, self).__init__()
    def forward(self, alpha, mask):
        # 3x3 Sobel kernels, shaped (out=1, in=1, 3, 3) for single-channel conv2d.
        sobel_x = Variable(torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view(1, 1, 3, 3).cuda())
        sobel_y = Variable(torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view(1, 1, 3, 3).cuda())
        G_x = F.conv2d(alpha, sobel_x, padding=1)
        G_y = F.conv2d(alpha, sobel_y, padding=1)
        return (torch.sum(torch.abs(G_x)) + torch.sum(torch.abs(G_y))) / torch.sum(mask)
class GANloss(_Loss):
    """LSGAN loss averaged over a multi-scale discriminator's outputs.

    Args to forward:
        pred: list of per-scale discriminator outputs; pred[i][0] is the
              final map of scale i (only that map is scored here).
        label_type: truthy -> target 1 (real); falsy -> target 0 (fake).

    Improvements over the original:
      * nn.MSELoss is built once in __init__ instead of on every call;
      * targets use torch.ones_like/zeros_like, which inherit pred's device
        and dtype — same behaviour on GPU, no CPU->GPU copy, and the loss
        also works on CPU tensors.
    """
    def __init__(self):
        super(GANloss, self).__init__()
        self.mse = nn.MSELoss()
    def forward(self, pred, label_type):
        loss = 0
        for scale_out in pred:
            if label_type:
                labels = torch.ones_like(scale_out[0])
            else:
                labels = torch.zeros_like(scale_out[0])
            loss = loss + self.mse(scale_out[0], labels)
        # Average over the discriminator scales.
        return loss / len(pred)
def normalized_l1_loss(alpha, alpha_pred, mask):
    """Masked L1 loss, normalised per sample by the mask area.

    For each batch element i:
        sum |alpha_i*mask_i - alpha_pred_i*mask_i| / (sum mask_i + eps)
    and the result is averaged over the batch.

    Improvement: vectorised over the batch dimension instead of a Python
    per-sample loop — identical result, one kernel launch per reduction.
    """
    eps = 1e-6
    b = alpha.shape[0]
    # Per-sample masked absolute difference, flattened over non-batch dims.
    diff = torch.abs(alpha * mask - alpha_pred * mask).reshape(b, -1).sum(dim=1)
    area = mask.reshape(b, -1).sum(dim=1) + eps
    return (diff / area).mean()
def gauss_kernel(size=5, device=torch.device('cpu'), channels=3):
    """Build the 5x5 binomial (Gaussian) kernel, replicated per channel for a
    grouped conv2d, shape (channels, 1, 5, 5). `size` is kept for interface
    compatibility; the kernel is fixed at 5x5."""
    # Outer product of the binomial row [1, 4, 6, 4, 1] with itself, then
    # normalise by 256 (= 16*16) so the kernel sums to 1.
    row = torch.tensor([1., 4., 6., 4., 1.])
    kernel_2d = (row[:, None] * row[None, :]) / 256.
    return kernel_2d.repeat(channels, 1, 1, 1).to(device)
def downsample(x):
    """Stride-2 subsampling of an NCHW tensor: keep every other row and column."""
    return x[..., ::2, ::2]
def upsample(x):
    """Classic Laplacian-pyramid expand: zero-insertion 2x upsample followed by
    Gaussian smoothing (kernel scaled by 4 to preserve overall brightness)."""
    b, c, h, w = x.shape
    # Place x at the even (row, col) positions of a double-size zero canvas —
    # equivalent to the cat/view/permute interleaving of the original.
    canvas = torch.zeros(b, c, h * 2, w * 2, device=x.device, dtype=x.dtype)
    canvas[:, :, ::2, ::2] = x
    return conv_gauss(canvas, 4 * gauss_kernel(channels=c, device=x.device))
def conv_gauss(img, kernel):
    """Depthwise (grouped) 5x5 convolution with 2-pixel reflect padding, so the
    output has the same spatial size as the input."""
    padded = F.pad(img, (2, 2, 2, 2), mode='reflect')
    return F.conv2d(padded, kernel, groups=padded.shape[1])
def laplacian_pyramid(img, kernel, max_levels=5):
    """Build a `max_levels`-entry Laplacian pyramid: each entry is the detail
    lost when blurring + downsampling + re-upsampling the current level."""
    pyr = []
    current = img
    for _ in range(max_levels):
        blurred = conv_gauss(current, kernel)
        reduced = downsample(blurred)
        # Band-pass residual at this scale.
        pyr.append(current - upsample(reduced))
        current = reduced
    return pyr
def mask_pyramid(img, max_levels=5):
    """Return [img, down1, down2, ...] — the mask at the original resolution
    plus `max_levels` successive stride-2 subsamplings (max_levels+1 entries)."""
    pyr = [img]
    current = img
    for _ in range(max_levels):
        current = downsample(current)
        pyr.append(current)
    return pyr
class LapLoss(_Loss):
    """Multi-scale masked L1 loss over Laplacian pyramids; level i is weighted
    by 2**i so coarse (small) levels count more per pixel."""
    def __init__(self, max_levels=5, channels=3, device=torch.device('cuda')):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        # Pre-build the Gaussian kernel once on the target device.
        self.gauss_kernel = gauss_kernel(channels=channels, device=device)
    def forward(self, input, target, mask):
        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_mask = mask_pyramid(img=mask, max_levels=self.max_levels)
        total = 0
        for level, (lap_in, lap_tgt, lvl_mask) in enumerate(zip(pyr_input, pyr_target, pyr_mask)):
            total = total + normalized_l1_loss(lap_in, lap_tgt, lvl_mask) * (2 ** level)
        return total
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.autograd import Variable | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from RAdam.radam import RAdam | |
| from tensorboardX import SummaryWriter | |
| import torch.nn.functional as F | |
| import os | |
| import time | |
| import argparse | |
| import numpy as np | |
| from torchvision import transforms | |
| from dataloader import AdobeCompositeTestData,OurAdobeDataAffineHRV3 | |
| from functions import * | |
| from networks.discriminator import MultiscaleDiscriminator,conv_init | |
| from networks.config_v3 import cfg | |
| from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss, GANloss, LapLoss, ExclusionLoss, MultiLoss, CrossEntropyLoss,detailed_alpha_gradient_loss | |
| from networks.transforms import trimap_transform, groupnorm_normalise_image,trimap_transform_batch | |
| from networks.models import fba_fusion, build_modelV3 | |
| from dataloader import np_to_torch,scale_input | |
| from evaulate import test_netM,compute_gradient_loss,compute_connectivity_error,compute_mse_loss,compute_sad_loss,compute_ce_loss | |
| from networks import layers_WS | |
| from utils import group_weight, collate_filter_none, adjust_lr_v2, get_params, groupnorm_normalise_image_cuda, trimap_transform_cuda | |
| from networks.sync_batchnorm.replicate import patch_replication_callback | |
#CUDA
"""Parses arguments."""
# Command-line configuration for training the FBA-style matting network
# on the Adobe composite dataset.
parser = argparse.ArgumentParser(description='Training Background Matting on Adobe Dataset.')
parser.add_argument('--encoder', default='resnet50_GN_WS', help="encoder model")
parser.add_argument('--decoder', default='fba_decoder', help="Decoder model")
# Pretrained FBA weights file (consumed elsewhere, e.g. by build_modelV3 — confirm).
parser.add_argument('--fba_weights', default='FBA.pth')
# Non-zero -> resume training from --checkpointM.
parser.add_argument('--has_checkpoint', default=0, type=int)
parser.add_argument('--checkpointM', default='',type=str)
parser.add_argument('-n', '--name', type=str, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, default=1, help='Batch Size.')
parser.add_argument('-res', '--reso', default=416, type=int, help='Input image resolution')
parser.add_argument('-epoch', '--epoch', type=int, default=45, help='Maximum Epoch')
parser.add_argument('-g', '--gpus', type=str, default="2", help='gpu id')
# NOTE(review): help text says 'gpu id' but this flag is the output-root prefix.
parser.add_argument('-p', '--prefix', type=str, default="./", help='gpu id')
# Loss-ablation switches; 0/1 ints used as booleans throughout the script.
parser.add_argument('--use_ab_exclusion_loss', type=int, default=0, help='use gradient loss')
parser.add_argument('--use_gradient_loss', type=int, default=0, help='use gradient loss')
parser.add_argument('--gradient_loss_detailed', type=int, default=0, help='use gradient loss')
#when training, to reduce the inference time, we do not use the full image to get the result
parser.add_argument('--use_full_image', type=int, default=0, help='use full image')
args=parser.parse_args()
# torch.autograd.set_detect_anomaly(True)
# Restrict the visible GPUs before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpus
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
# Master device is the first GPU in the list; fall back to CPU without CUDA.
# NOTE(review): after CUDA_VISIBLE_DEVICES masking, visible devices are
# renumbered from 0, so 'cuda:<original id>' may not exist when the list
# does not start at 0 — confirm this works for e.g. --gpus "2".
device = 'cuda:'+args.gpus.split(",")[0] if torch.cuda.is_available() else 'cpu'
# sfd_detector = FaceDetector(device=device, path_to_detector='/kaggle/input/s3fdface/s3fd/s3fd-619a316812.pth', verbose=False)
torch.cuda.set_device(device)
# Output layout: <prefix>/FBA_GAN_Matting/{TB_Summary,Models,OutputImages}/<name>
os.makedirs(os.path.join(args.prefix,"FBA_GAN_Matting"),exist_ok=True)
os.makedirs(os.path.join(args.prefix,"FBA_GAN_Matting","TB_Summary"),exist_ok=True)
os.makedirs(os.path.join(args.prefix,"FBA_GAN_Matting","Models"),exist_ok=True)
os.makedirs(os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages"),exist_ok=True)
##Directories
tb_dir=os.path.join(args.prefix,"FBA_GAN_Matting","TB_Summary/"+ args.name)
model_dir=os.path.join(args.prefix,"FBA_GAN_Matting","Models/"+ args.name)
img_dir=os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages/"+ args.name)
#all these with fba fusion
# Test-output folders depend on whether evaluation runs on full images or crops.
if args.use_full_image==1:
    img_dir_gt=os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages",args.name,"gt_full")
    img_dir_netT=os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages",args.name,"netT_full")
else:
    img_dir_gt=os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages",args.name,"gt_crop")
    img_dir_netT=os.path.join(args.prefix,"FBA_GAN_Matting","OutputImages",args.name,"netT_crop")
# Create every run directory idempotently (matches the exist_ok idiom above).
# The original guarded creation with `if not os.path.exists(img_dir)`, which
# skipped img_dir_gt/img_dir_netT whenever img_dir already existed (e.g. a
# rerun with --use_full_image flipped), leading to a later save failure.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(tb_dir, exist_ok=True)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(img_dir_gt, exist_ok=True)
os.makedirs(img_dir_netT, exist_ok=True)
## Input list
#To Be Adjusted
# Data-loading configs: training uses the CLI resolution, testing a fixed 320x320.
# 'trimapK' presumably controls the trimap dilation kernel range — confirm in dataloader.
data_config_train = {'reso': (args.reso,args.reso), 'trimapK': [5,5]} #if trimap is true, rcnn is used
data_config_test = {'reso': (320,320), 'trimapK': [5,5], 'noise': False} # choice for data loading parameters
# DATA LOADING
print('\n[Phase 1] : Data Preparation')
#Original Data
# Photometric augmentation applied to each training image.
train_data_transfrom = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    transforms.ToTensor()
])
traindata = OurAdobeDataAffineHRV3(csv_file='Data_adobe/Adobe_train_data.csv',data_config=data_config_train,transform=train_data_transfrom,prefix=args.prefix) #Write a dataloader function that can read the database provided by .csv file
# collate_filter_none drops failed samples; num_workers is tied to batch_size here.
train_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True, num_workers=args.batch_size, collate_fn=collate_filter_none,drop_last=True)
testdata=AdobeCompositeTestData(csv_file="Data_adobe/Adobe_test_data.csv",data_config=data_config_test, transform=None,prefix=args.prefix,use_full_image=args.use_full_image)
test_loader=iter(testdata)
print('\n[Phase 2] : Initialization')
# Matting network (encoder/decoder chosen from the CLI).
netM=build_modelV3(args.encoder, args.decoder, 'default')
gpu_ids=[int(gpu_id) for gpu_id in args.gpus.split(",")]
# Multi-GPU: DataParallel plus the synchronized-BN replication patch.
if len(args.gpus.split(","))>1:
    netM=nn.DataParallel(netM,device_ids=gpu_ids)
    patch_replication_callback(netM)
# Let cuDNN pick the fastest algorithms for the (fixed) input sizes.
torch.backends.cudnn.benchmark=True
#Loss
# Loss terms used in the training loop below.
l1_loss=alpha_loss()
c_loss=compose_loss()
# Optional gradient-loss variant that also penalises gradient direction.
if args.gradient_loss_detailed==1:
    g_loss=detailed_alpha_gradient_loss()
else:
    g_loss=alpha_gradient_loss()
lap_loss=LapLoss()
excl_loss=ExclusionLoss()
# Per-parameter-group lr/weight-decay setup, optimised with RAdam.
group_params=group_weight(netM, 1e-5, 1e-5, 0.0005)
optimizerM=RAdam(group_params)
# Single lr drop (x0.1) at epoch 40.
schedulerM = torch.optim.lr_scheduler.MultiStepLR(optimizerM, [40], gamma=0.1)
start_epoch=0
start_loss=0
#default epoch of netT and netM is the same!
# Optionally resume model/optimizer state and epoch counter from a checkpoint.
if args.has_checkpoint!=0:
    checkpointM = torch.load(args.checkpointM)
    netM.load_state_dict(checkpointM['model_state_dict'])
    optimizerM.load_state_dict(checkpointM['optimizer_state_dict'])
    start_epoch = checkpointM['epoch']+1
    start_loss = checkpointM['loss']
    print("load checkpoints successfully: epoch {}; loss {}".format(start_epoch-1,start_loss))
# TensorBoard writer used for all scalar/image logging in the training loop.
log_writer=SummaryWriter(tb_dir)
# Persist the effective configuration (network cfg + CLI args) next to the models.
with open(os.path.join(model_dir,'config.txt'),'w') as f:
    import json
    f.write(json.dumps(cfg.__dict__)+"\n")
    f.write("encoder: "+ args.encoder+"\n")
    f.write("decoder: "+ args.decoder+"\n")
    f.write("fba_weights: "+ args.fba_weights+"\n")
    f.write("has_checkpoint: "+ str(args.has_checkpoint)+"\n")
    f.write("checkpointM: "+ args.checkpointM+"\n")
    f.write("name: "+ args.name+"\n")
    f.write("batch_size: "+ str(args.batch_size)+"\n")
    f.write("reso: ("+ str(args.reso)+","+str(args.reso)+")\n")
    f.write("epoch: "+ str(args.epoch)+"\n")
    f.write("gpus: "+ args.gpus+"\n")
    f.write("prefix: "+ args.prefix+"\n")
    f.write("use_ab_exclusion_loss: "+ str(args.use_ab_exclusion_loss)+"\n")
    f.write("use_gradient_loss: "+ str(args.use_gradient_loss)+"\n")
    f.write("gradient_loss_detailed: "+ str(args.gradient_loss_detailed)+"\n")
    f.write("use_full_image: "+ str(args.use_full_image)+"\n")
print('Starting Training')
step=50   # console-print / TB-image logging interval (iterations)
KK=len(train_loader)   # iterations per epoch; used to build a global step index
test_time_start=time.time()
# ---- main training loop ----
for epoch in range(start_epoch,args.epoch):
    # Fresh test iterator each epoch.
    test_loader=iter(testdata)
    netM.train()
    # Windowed loss accumulators, reset every `step` iterations for printing.
    netL = 0
    lT, lM =0,0
    alL, lapL, gradL, compL, fbL = 0,0,0,0,0
    abL=0
    elapse_run, elapse=0,0
    t0=time.time()
    # Epoch-level accumulators; testL/ct_tst is stored in the checkpoint as 'loss'.
    testL=0; ct_tst=0
    testLM=0; testLT=0
    test_time_start=time.time()
    for i,data in enumerate(train_loader):
        #Initiating
        #image is the same as image_torch
        fg, alpha, image, seg, trimap, bg = data['fg'], data['alpha'], data['image'], data['seg'], data['classify_target_trimap'], data['back']
        ori_trimap_=data['ori_trimap_']
        image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch = data['netm_input_image'], data['netm_input_trimap'], data['image_transformed_torch'], data['trimap_transformed_torch']
        # Move the whole batch to the GPU.
        fg, alpha, image, seg = Variable(fg.cuda()), Variable(alpha.cuda()), Variable(image.cuda()), Variable(seg.cuda())
        ori_trimap_=Variable(ori_trimap_.cuda())
        trimap, bg = Variable(trimap.cuda()), Variable(bg.cuda())
        image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch = Variable(image_torch.cuda()), Variable(trimap_torch.cuda()), Variable(image_transformed_torch.cuda()), Variable(trimap_transformed_torch.cuda())
        tr0=time.time()
        # Forward pass: the network predicts alpha + foreground + background (FBA).
        fba_pred=netM(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
        #get the fg, bg, alpha. They are all between 0 and 1.
        # Output channels: [0] alpha, [1:4] foreground RGB, [4:7] background RGB.
        alpha_pred = fba_pred[:,0:1,:,:]
        fg_pred = fba_pred[:,1:4,:,:]
        bg_pred = fba_pred[:, 4:7,:,:]
        #masks
        mask0=Variable(torch.ones(alpha.shape).cuda()) #whole image mask
        mask=(alpha>0.01).type(torch.cuda.FloatTensor) #alpha mask fg, quite loose
        mask1=(seg>0.95).type(torch.cuda.FloatTensor) #seg create fg mask
        mask2=(seg<=0.95).type(torch.cuda.FloatTensor) #seg create bg mask
        ###### matting loss ######
        ##l1_alpha, l1_c, l1_lap, l1_g, l1_fb, l1_excl_fb, l1_c_fb, l1_lap_fb
        ##l_fba=l1_alpha+l1_c+l1_lap+l1_g+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        l1_alpha=l1_loss(alpha,alpha_pred.float(),mask0) #use the whole image as the mask
        #because the bg and bg of image shoud be the same, use mask1, a bit tight fg mask to calculate the loss
        # print('pre_l1_c',l1_c)
        l1_c=c_loss(image_torch,alpha_pred,fg_pred,bg,mask1)
        #print("l1_c",l1_c,l1_c.data)
        #use the whole image mask to calculate the lap loss
        l1_lap=lap_loss(alpha,alpha_pred,mask0)
        if args.use_gradient_loss==1:
            l1_g=g_loss(alpha,alpha_pred,mask0) # it seems like gradient will make results worse, so drop it!
        if args.use_ab_exclusion_loss==1:
            # Exclusion loss between the alpha-weighted foreground and the background.
            l1_excl_ab=excl_loss(alpha_pred.float()*fg.float(), bg.float(), level=3)
        #FB loss
        #use the loose fg mask to calcualte the fg loss and use loss bg mask to calculate the bg loss
        l1_fb=l1_loss(fg,fg_pred,mask0)+l1_loss(bg,bg_pred,mask2)
        #force the network to separate fg and bg
        l1_excl_fb=excl_loss(fg_pred, bg_pred, level=3)
        #when use the whole predicted image as a composite image, the whole image mask is used to calculate the loss
        l1_c_fb=c_loss(image_torch,alpha,fg_pred,bg_pred,mask0)
        #use the whole image mask to calculate the lap loss
        l1_lap_fb=lap_loss(fg,fg_pred,mask0)+lap_loss(bg,bg_pred,mask0)
        # Combine the loss terms depending on the two ablation switches.
        if args.use_gradient_loss==1:
            if args.use_ab_exclusion_loss==1:
                loss=l1_alpha+l1_c+l1_lap+l1_g+l1_excl_ab+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            else:
                loss=l1_alpha+l1_c+l1_lap+l1_g+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            #matting_loss=l1_alpha+l1_c+l1_lap+l1_g+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        else:
            if args.use_ab_exclusion_loss==1:
                loss=l1_alpha+l1_c+l1_lap+l1_excl_ab+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            else:
                loss=l1_alpha+l1_c+l1_lap+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            #matting_loss=l1_alpha+l1_c+l1_lap+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        # Standard optimisation step.
        optimizerM.zero_grad()
        loss.backward()
        optimizerM.step()
        # Accumulate scalars for the windowed console prints.
        netL+=loss.data
        lM += loss.data
        alL += l1_alpha.data
        lapL += l1_lap.data
        if args.use_gradient_loss==1:
            gradL += l1_g.data
        compL += l1_c.data
        if args.use_ab_exclusion_loss==1:
            abL+=l1_excl_ab.data
        fbL += (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data)
        # Per-iteration TensorBoard scalars; global step = epoch*KK + i + 1.
        log_writer.add_scalar('Matting Loss', loss.data, epoch*KK + i + 1)
        #l1_alpha+l1_c+l1_lap+l1_g+0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        log_writer.add_scalar('Matting Loss: Alpha', l1_alpha.data, epoch*KK + i + 1)
        log_writer.add_scalar('Matting Loss: Comp', l1_c.data, epoch*KK + i + 1)
        log_writer.add_scalar('Matting Loss: Lap', l1_lap.data, epoch*KK + i + 1)
        if args.use_ab_exclusion_loss==1:
            log_writer.add_scalar('Matting Loss: Excl AB', l1_excl_ab.data, epoch*KK + i + 1)
        if args.use_gradient_loss==1:
            log_writer.add_scalar('Matting Loss: Gradient', l1_g.data, epoch*KK + i + 1)
            # NOTE(review): the FB scalar includes l1_g in this branch but not
            # in the else branch below — confirm that is intentional.
            log_writer.add_scalar('Matting Loss: FB', (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data+l1_g.data), epoch*KK + i + 1)
        else:
            log_writer.add_scalar('Matting Loss: FB', (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data), epoch*KK + i + 1)
        t1=time.time()
        elapse +=t1 -t0      # wall time including data loading
        elapse_run += t1-tr0 # forward+backward time only
        t0=t1
        testL+=loss.data
        testLM+=loss.data
        ct_tst+=1
        # Every `step` iterations: print windowed averages and dump images to TB.
        if i % step == (step-1):
            if args.use_gradient_loss==1:
                #have gradient loss
                if args.use_ab_exclusion_loss==1:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Gradient-loss: %.4f Excl-AB-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step,lM/step, alL/step, compL/step, lapL/step, gradL/step, abL/step, fbL/step, elapse/step, elapse_run/step))
                else:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Gradient-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step,lM/step, alL/step, compL/step, lapL/step, gradL/step, fbL/step, elapse/step, elapse_run/step))
            else:
                if args.use_ab_exclusion_loss==1:
                    #do not have gradient loss
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Excl-AB-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, abL/step, fbL/step, elapse/step, elapse_run/step))
                else:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, fbL/step, elapse/step, elapse_run/step))
            # Reset the windowed accumulators.
            netL = 0
            lT, lM =0,0
            alL, lapL, gradL, compL, fbL = 0,0,0,0,0
            abL=0
            elapse_run, elapse=0,0
            # Dump current batch tensors as TensorBoard images.
            write_tb_log(image,'image',log_writer,i)
            write_tb_log(seg,'seg',log_writer,i)
            write_tb_log(alpha,'alpha',log_writer,i)
            write_tb_log(alpha_pred,'alpha_pred',log_writer,i)
            #use a quite loose fg mask to mask the fg
            write_tb_log(fg*mask,'fg',log_writer,i)
            write_tb_log(fg_pred,'fg_pred',log_writer,i)
            write_tb_log(bg*mask2,'bg',log_writer,i)
            write_tb_log(bg_pred,'bg_pred',log_writer,i)
            write_tb_log(ori_trimap_,'trimap_',log_writer,i)
            #composition
            comp1=fg_pred*alpha_pred + (1-alpha_pred)*bg_pred
            write_tb_log(comp1,'composite_pred',log_writer,i)
            del comp1
        #add more item to delete to clear the memory
        # Explicitly drop batch tensors to release GPU memory each iteration.
        del fg,alpha,image,seg,ori_trimap_,trimap,bg,image_torch,trimap_torch
        del image_transformed_torch,trimap_transformed_torch
        del alpha_pred,fg_pred,bg_pred,fba_pred
        # del l1_alpha,l1_c,l1_lap
        # del l1_fb,l1_excl_fb,l1_c_fb,l1_lap_fb
        # del l1_excl_ab
        del mask,mask0,mask1,mask2
        # if args.use_gradient_loss==1:
        # del l1_g
    schedulerM.step()
    # ---- end-of-epoch evaluation and checkpointing ----
    netM.eval()
    # NOTE(review): 'mean_sad_gt' appears twice in this unpacking, so the second
    # value returned by test_netM is silently overwritten by the third —
    # confirm the intended variable names against test_netM's return order.
    mean_mse_gt,mean_sad_gt,mean_sad_gt,mean_conn_gt,mean_gradient_gt=test_netM(test_loader,netM,img_dir_netT,img_dir_gt,epoch,args.use_full_image,log_writer=log_writer,timestamp=(epoch+1)*KK)
    log_writer.add_scalar('Test MSE GT Loss', mean_mse_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test SAD GT Loss', mean_sad_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test CONN GT Error', mean_conn_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test Gradient GT Loss', mean_gradient_gt, (epoch+1)*KK)
    print("**************************Test Results*****************************")
    print('[Epoch %d] MSE-GT-loss: %.4f SAD-GT-loss: %.4f CONN-Error-GT: %.4f Gradient-GT-loss: %.4f' %(epoch + 1, mean_mse_gt, mean_sad_gt, mean_conn_gt, mean_gradient_gt))
    # Save a checkpoint every epoch; stored loss is the epoch-average training loss.
    torch.save({
        'epoch': epoch,
        'model_state_dict': netM.state_dict(),
        'optimizer_state_dict': optimizerM.state_dict(),
        'loss': testL/ct_tst,
    }, model_dir + '/netM_models_epoch_%d_%.4f.pth' %(epoch,testL/ct_tst))
    test_time_end=time.time()
    print("Epoch {} Done".format(epoch),(test_time_end-test_time_start)/60.0,"minutes")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment