@zoezhou1999 · Last active August 17, 2020
matting
# loss_functions.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import matplotlib.pyplot as plt
import pdb
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
# import scipy.io as sio

def compute_gradient(img):
    # forward differences along the spatial dims (H, W) of an NCHW tensor
    gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
    grady = img[:, :, :, 1:] - img[:, :, :, :-1]
    return gradx, grady
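# Shape sketch (assuming NCHW input): for img of shape (B, C, H, W),
# gradx is (B, C, H-1, W) and grady is (B, C, H, W-1).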
# multi-task loss with learned uncertainty weighting; the log-variances are
# registered as Parameters so an optimizer can update them
class MultiLoss(_Loss):
    def __init__(self):
        super(MultiLoss, self).__init__()
        self.sigma_A = nn.Parameter(torch.zeros(1).float().cuda())
        self.sigma_B = nn.Parameter(torch.zeros(1).float().cuda())

    def forward(self, lossA, lossB):
        return torch.exp(-self.sigma_A)*lossA + torch.exp(-self.sigma_B)*lossB + self.sigma_A + self.sigma_B
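# Usage sketch (hypothetical model netM and scalar losses lossA/lossB on the
# same device; the sigmas join the optimizer so the weighting is learned):
#   multi_loss = MultiLoss()
#   optimizer = torch.optim.Adam(list(netM.parameters()) + list(multi_loss.parameters()))
#   total = multi_loss(lossA, lossB)  # weighted sum plus the regularising sigmas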
class CrossEntropyLoss(_Loss):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, trimap_pred, trimap_gt):
        crossEntropyLoss = nn.CrossEntropyLoss()
        return crossEntropyLoss(trimap_pred, trimap_gt)
class ExclusionLoss(_Loss):
    def __init__(self):
        super(ExclusionLoss, self).__init__()

    def forward(self, img1, img2, level=3):
        gradx_loss = []
        grady_loss = []
        eps = 1e-8
        for l in range(level):
            gradx1, grady1 = compute_gradient(img1)
            gradx2, grady2 = compute_gradient(img2)
            alphax = 2.0*torch.mean(torch.abs(gradx1))/(torch.mean(torch.abs(gradx2)) + eps)
            alphay = 2.0*torch.mean(torch.abs(grady1))/(torch.mean(torch.abs(grady2)) + eps)
            gradx1_s = (torch.sigmoid(gradx1)*2) - 1
            grady1_s = (torch.sigmoid(grady1)*2) - 1
            gradx2_s = (torch.sigmoid(gradx2*alphax)*2) - 1
            grady2_s = (torch.sigmoid(grady2*alphay)*2) - 1
            gradx_loss.append((torch.mean(torch.mul(torch.mul(gradx1_s, gradx1_s), torch.mul(gradx2_s, gradx2_s)), dim=(1, 2, 3)) + eps)**0.25)
            grady_loss.append((torch.mean(torch.mul(torch.mul(grady1_s, grady1_s), torch.mul(grady2_s, grady2_s)), dim=(1, 2, 3)) + eps)**0.25)
            img1 = F.avg_pool2d(img1, (2, 2), (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2), (2, 2))
        loss_gradxy = torch.sum(sum(gradx_loss)/(1.0*level)) + torch.sum(sum(grady_loss)/(1.0*level))
        return loss_gradxy
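# Smoke-test sketch (CPU, random layers; spatial sizes assumed divisible by
# 2**level so the pooling halves cleanly; smaller values mean less gradient overlap):
#   f = torch.rand(2, 3, 64, 64)
#   b = torch.rand(2, 3, 64, 64)
#   print(ExclusionLoss()(f, b, level=3))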
class alpha_loss(_Loss):
    def __init__(self):
        super(alpha_loss, self).__init__()

    def forward(self, alpha, alpha_pred, mask):
        # return F.l1_loss(alpha, alpha_pred)
        return normalized_l1_loss(alpha, alpha_pred, mask)

class compose_loss(_Loss):
    def __init__(self):
        super(compose_loss, self).__init__()

    def forward(self, image, alpha_pred, fg, bg, mask):
        # alpha_pred = (alpha_pred + 1)/2
        comp = fg*alpha_pred + (1 - alpha_pred)*bg
        # return F.l1_loss(image, comp)
        return normalized_l1_loss(image, comp, mask)
class alpha_gradient_loss(_Loss):
    def __init__(self):
        super(alpha_gradient_loss, self).__init__()

    def forward(self, alpha, alpha_pred, mask):
        # Sobel filters for horizontal/vertical gradients
        fx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view((1, 1, 3, 3)); fx = Variable(fx.cuda())
        fy = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view((1, 1, 3, 3)); fy = Variable(fy.cuda())
        G_x = F.conv2d(alpha, fx, padding=1); G_y = F.conv2d(alpha, fy, padding=1)
        G_x_pred = F.conv2d(alpha_pred, fx, padding=1); G_y_pred = F.conv2d(alpha_pred, fy, padding=1)
        loss = normalized_l1_loss(G_x, G_x_pred, mask) + normalized_l1_loss(G_y, G_y_pred, mask)
        return loss
######################################
### TO DO
class detailed_alpha_gradient_loss(_Loss):
    def __init__(self):
        super(detailed_alpha_gradient_loss, self).__init__()

    def forward(self, alpha, alpha_pred, mask):
        fx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view((1, 1, 3, 3)); fx = Variable(fx.cuda())
        fy = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view((1, 1, 3, 3)); fy = Variable(fy.cuda())
        G_x = F.conv2d(alpha, fx, padding=1); G_y = F.conv2d(alpha, fy, padding=1)
        G_x_pred = F.conv2d(alpha_pred, fx, padding=1); G_y_pred = F.conv2d(alpha_pred, fy, padding=1)
        # gradient magnitude term
        loss_grad_mag = normalized_l1_loss(G_x, G_x_pred, mask) + normalized_l1_loss(G_y, G_y_pred, mask)
        # gradient direction term: |gt_grad_y*pred_grad_x - gt_grad_x*pred_grad_y|
        loss_grad_dir = normalized_l1_loss(torch.mul(G_y, G_x_pred), torch.mul(G_x, G_y_pred), mask)
        loss = loss_grad_mag + loss_grad_dir
        # possible surface-normal term (unused):
        # normal_product = 1 + torch.mul(G_y, G_y_pred) + torch.mul(G_x, G_x_pred)
        # gt_normal_mag = torch.sqrt(1 + torch.square(G_y) + torch.square(G_x))
        # pd_normal_mag = torch.sqrt(1 + torch.square(G_y_pred) + torch.square(G_x_pred))
        # loss_normal = torch.mean(1 - normal_product/(gt_normal_mag*pd_normal_mag))
        return loss
class alpha_gradient_reg_loss(_Loss):
    def __init__(self):
        super(alpha_gradient_reg_loss, self).__init__()

    def forward(self, alpha, mask):
        fx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view((1, 1, 3, 3)); fx = Variable(fx.cuda())
        fy = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view((1, 1, 3, 3)); fy = Variable(fy.cuda())
        G_x = F.conv2d(alpha, fx, padding=1); G_y = F.conv2d(alpha, fy, padding=1)
        loss = (torch.sum(torch.abs(G_x)) + torch.sum(torch.abs(G_y)))/torch.sum(mask)
        return loss
class GANloss(_Loss):
    def __init__(self):
        super(GANloss, self).__init__()

    def forward(self, pred, label_type):
        MSE = nn.MSELoss()
        loss = 0
        for i in range(0, len(pred)):
            if label_type:
                labels = torch.ones(pred[i][0].shape)
            else:
                labels = torch.zeros(pred[i][0].shape)
            labels = Variable(labels.cuda())
            loss += MSE(pred[i][0], labels)
        return loss/len(pred)
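# Usage sketch (assuming pred follows the multiscale-discriminator convention
# in which pred[i][0] is the score map at scale i; netD is hypothetical here):
#   d_out = netD(torch.cat([image, alpha_pred], dim=1))
#   loss_G = GANloss()(d_out, True)        # generator step: wants "real" labels
#   loss_D_fake = GANloss()(d_out, False)  # discriminator step on fakes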
def normalized_l1_loss(alpha, alpha_pred, mask):
    # per-sample L1 inside the mask, normalised by the mask area, averaged over the batch
    loss = 0; eps = 1e-6
    for i in range(alpha.shape[0]):
        # if mask[i, ...].sum() > 0:
        loss = loss + torch.sum(torch.abs(alpha[i, ...]*mask[i, ...] - alpha_pred[i, ...]*mask[i, ...]))/(torch.sum(mask[i, ...]) + eps)
    loss = loss/alpha.shape[0]
    return loss
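# Worked example: with the mask covering only the top row of a 2x2 alpha,
# the loss is (|1 - 0.5| + |0 - 0|)/2 = 0.25 (up to eps):
#   a = torch.tensor([[[[1., 0.], [0., 0.]]]])
#   p = torch.tensor([[[[0.5, 0.], [0., 0.]]]])
#   m = torch.tensor([[[[1., 1.], [0., 0.]]]])
#   print(normalized_l1_loss(a, p, m))  # ~0.25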
def gauss_kernel(size=5, device=torch.device('cpu'), channels=3):
    # 5x5 binomial kernel, repeated per channel for a depthwise convolution
    kernel = torch.tensor([[1., 4., 6., 4., 1.],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    kernel /= 256.
    kernel = kernel.repeat(channels, 1, 1, 1)
    kernel = kernel.to(device)
    return kernel

def downsample(x):
    return x[:, :, ::2, ::2]

def upsample(x):
    # zero-interleave rows and columns, then smooth with 4x the Gaussian kernel
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1], device=x.device))

def conv_gauss(img, kernel):
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    out = F.conv2d(img, kernel, groups=img.shape[1])
    return out
def laplacian_pyramid(img, kernel, max_levels=5):
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        diff = current - up
        pyr.append(diff)
        current = down
    return pyr

def mask_pyramid(img, max_levels=5):
    current = img
    pyr = [current]
    for level in range(max_levels):
        down = downsample(current)
        pyr.append(down)
        current = down
    return pyr

class LapLoss(_Loss):
    def __init__(self, max_levels=5, channels=3, device=torch.device('cuda')):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.gauss_kernel = gauss_kernel(channels=channels, device=device)

    def forward(self, input, target, mask):
        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_mask = mask_pyramid(img=mask, max_levels=self.max_levels)
        return sum(normalized_l1_loss(a, b, c)*(2**i) for i, (a, b, c) in enumerate(zip(pyr_input, pyr_target, pyr_mask)))
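# A self-contained CPU smoke test for the pyramid losses (a minimal sketch,
# not part of the training pipeline; shapes assumed divisible by 2**max_levels):
if __name__ == '__main__':
    x = torch.rand(2, 3, 64, 64)
    y = torch.rand(2, 3, 64, 64)
    m = torch.ones(2, 3, 64, 64)
    lap = LapLoss(max_levels=5, channels=3, device=torch.device('cpu'))
    print('LapLoss:', lap(x, y, m).item())
    print('ExclusionLoss:', ExclusionLoss()(x, y, level=3).item())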
# training script (expects the module above to be importable as loss_functions)
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from RAdam.radam import RAdam
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import os
import time
import argparse
import json
import numpy as np
from torchvision import transforms
from dataloader import AdobeCompositeTestData, OurAdobeDataAffineHRV3
from functions import *
from networks.discriminator import MultiscaleDiscriminator, conv_init
from networks.config_v3 import cfg
from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss, GANloss, LapLoss, ExclusionLoss, MultiLoss, CrossEntropyLoss, detailed_alpha_gradient_loss
from networks.transforms import trimap_transform, groupnorm_normalise_image, trimap_transform_batch
from networks.models import fba_fusion, build_modelV3
from dataloader import np_to_torch, scale_input
from evaulate import test_netM, compute_gradient_loss, compute_connectivity_error, compute_mse_loss, compute_sad_loss, compute_ce_loss
from networks import layers_WS
from utils import group_weight, collate_filter_none, adjust_lr_v2, get_params, groupnorm_normalise_image_cuda, trimap_transform_cuda
from networks.sync_batchnorm.replicate import patch_replication_callback

# CUDA
"""Parses arguments."""
parser = argparse.ArgumentParser(description='Training Background Matting on Adobe Dataset.')
parser.add_argument('--encoder', default='resnet50_GN_WS', help="encoder model")
parser.add_argument('--decoder', default='fba_decoder', help="Decoder model")
parser.add_argument('--fba_weights', default='FBA.pth')
parser.add_argument('--has_checkpoint', default=0, type=int)
parser.add_argument('--checkpointM', default='',type=str)
parser.add_argument('-n', '--name', type=str, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, default=1, help='Batch Size.')
parser.add_argument('-res', '--reso', default=416, type=int, help='Input image resolution')
parser.add_argument('-epoch', '--epoch', type=int, default=45, help='Maximum Epoch')
parser.add_argument('-g', '--gpus', type=str, default="2", help='gpu id')
parser.add_argument('-p', '--prefix', type=str, default="./", help='gpu id')
parser.add_argument('--use_ab_exclusion_loss', type=int, default=0, help='use gradient loss')
parser.add_argument('--use_gradient_loss', type=int, default=0, help='use gradient loss')
parser.add_argument('--gradient_loss_detailed', type=int, default=0, help='use gradient loss')
#when training, to reduce the inference time, we do not use the full image to get the result
parser.add_argument('--use_full_image', type=int, default=0, help='use full image')
args=parser.parse_args()
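# Example invocation (hypothetical script name, run name, and GPU ids):
#   python train.py -n fba_run1 -bs 4 -g 0,1 -p ./ --use_gradient_loss 0 --use_full_image 0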
# torch.autograd.set_detect_anomaly(True)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
# after CUDA_VISIBLE_DEVICES masking, the visible devices are renumbered from 0,
# so the first selected GPU is always 'cuda:0'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# sfd_detector = FaceDetector(device=device, path_to_detector='/kaggle/input/s3fdface/s3fd/s3fd-619a316812.pth', verbose=False)
torch.cuda.set_device(device)
os.makedirs(os.path.join(args.prefix, "FBA_GAN_Matting"), exist_ok=True)
os.makedirs(os.path.join(args.prefix, "FBA_GAN_Matting", "TB_Summary"), exist_ok=True)
os.makedirs(os.path.join(args.prefix, "FBA_GAN_Matting", "Models"), exist_ok=True)
os.makedirs(os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages"), exist_ok=True)
## Directories
tb_dir = os.path.join(args.prefix, "FBA_GAN_Matting", "TB_Summary", args.name)
model_dir = os.path.join(args.prefix, "FBA_GAN_Matting", "Models", args.name)
img_dir = os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages", args.name)
# all of these are evaluated with fba fusion
if args.use_full_image == 1:
    img_dir_gt = os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages", args.name, "gt_full")
    img_dir_netT = os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages", args.name, "netT_full")
else:
    img_dir_gt = os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages", args.name, "gt_crop")
    img_dir_netT = os.path.join(args.prefix, "FBA_GAN_Matting", "OutputImages", args.name, "netT_crop")
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(tb_dir):
    os.makedirs(tb_dir)
if not os.path.exists(img_dir):
    os.makedirs(img_dir)
    os.makedirs(img_dir_gt)
    os.makedirs(img_dir_netT)
## Input list
# To Be Adjusted
data_config_train = {'reso': (args.reso, args.reso), 'trimapK': [5, 5]}  # if trimap is true, rcnn is used
data_config_test = {'reso': (320, 320), 'trimapK': [5, 5], 'noise': False}  # choices for data loading parameters

# DATA LOADING
print('\n[Phase 1] : Data Preparation')
# Original Data
train_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    transforms.ToTensor()
])
traindata = OurAdobeDataAffineHRV3(csv_file='Data_adobe/Adobe_train_data.csv', data_config=data_config_train, transform=train_data_transform, prefix=args.prefix)  # a dataloader that reads the database listed in the .csv file
train_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True, num_workers=args.batch_size, collate_fn=collate_filter_none, drop_last=True)
testdata = AdobeCompositeTestData(csv_file="Data_adobe/Adobe_test_data.csv", data_config=data_config_test, transform=None, prefix=args.prefix, use_full_image=args.use_full_image)
test_loader = iter(testdata)
print('\n[Phase 2] : Initialization')
netM = build_modelV3(args.encoder, args.decoder, 'default')
gpu_ids = [int(gpu_id) for gpu_id in args.gpus.split(",")]
if len(args.gpus.split(",")) > 1:
    # with CUDA_VISIBLE_DEVICES already set, DataParallel must use the remapped ids 0..n-1
    netM = nn.DataParallel(netM, device_ids=list(range(len(gpu_ids))))
    patch_replication_callback(netM)
torch.backends.cudnn.benchmark = True

# Losses
l1_loss = alpha_loss()
c_loss = compose_loss()
if args.gradient_loss_detailed == 1:
    g_loss = detailed_alpha_gradient_loss()
else:
    g_loss = alpha_gradient_loss()
lap_loss = LapLoss()
excl_loss = ExclusionLoss()

group_params = group_weight(netM, 1e-5, 1e-5, 0.0005)
optimizerM = RAdam(group_params)
schedulerM = torch.optim.lr_scheduler.MultiStepLR(optimizerM, [40], gamma=0.1)
start_epoch = 0
start_loss = 0
# default epoch of netT and netM is the same!
if args.has_checkpoint != 0:
    checkpointM = torch.load(args.checkpointM)
    netM.load_state_dict(checkpointM['model_state_dict'])
    optimizerM.load_state_dict(checkpointM['optimizer_state_dict'])
    start_epoch = checkpointM['epoch'] + 1
    start_loss = checkpointM['loss']
    print("loaded checkpoint successfully: epoch {}; loss {}".format(start_epoch - 1, start_loss))

log_writer = SummaryWriter(tb_dir)

with open(os.path.join(model_dir, 'config.txt'), 'w') as f:
    f.write(json.dumps(cfg.__dict__) + "\n")
    f.write("encoder: " + args.encoder + "\n")
    f.write("decoder: " + args.decoder + "\n")
    f.write("fba_weights: " + args.fba_weights + "\n")
    f.write("has_checkpoint: " + str(args.has_checkpoint) + "\n")
    f.write("checkpointM: " + args.checkpointM + "\n")
    f.write("name: " + args.name + "\n")
    f.write("batch_size: " + str(args.batch_size) + "\n")
    f.write("reso: (" + str(args.reso) + "," + str(args.reso) + ")\n")
    f.write("epoch: " + str(args.epoch) + "\n")
    f.write("gpus: " + args.gpus + "\n")
    f.write("prefix: " + args.prefix + "\n")
    f.write("use_ab_exclusion_loss: " + str(args.use_ab_exclusion_loss) + "\n")
    f.write("use_gradient_loss: " + str(args.use_gradient_loss) + "\n")
    f.write("gradient_loss_detailed: " + str(args.gradient_loss_detailed) + "\n")
    f.write("use_full_image: " + str(args.use_full_image) + "\n")
print('Starting Training')
step = 50
KK = len(train_loader)
test_time_start = time.time()
for epoch in range(start_epoch, args.epoch):
    test_loader = iter(testdata)
    netM.train()
    netL = 0
    lT, lM = 0, 0
    alL, lapL, gradL, compL, fbL = 0, 0, 0, 0, 0
    abL = 0
    elapse_run, elapse = 0, 0
    t0 = time.time()
    testL = 0; ct_tst = 0
    testLM = 0; testLT = 0
    test_time_start = time.time()
    for i, data in enumerate(train_loader):
        # Initialisation (image is the same as image_torch)
        fg, alpha, image, seg, trimap, bg = data['fg'], data['alpha'], data['image'], data['seg'], data['classify_target_trimap'], data['back']
        ori_trimap_ = data['ori_trimap_']
        image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch = data['netm_input_image'], data['netm_input_trimap'], data['image_transformed_torch'], data['trimap_transformed_torch']
        fg, alpha, image, seg = Variable(fg.cuda()), Variable(alpha.cuda()), Variable(image.cuda()), Variable(seg.cuda())
        ori_trimap_ = Variable(ori_trimap_.cuda())
        trimap, bg = Variable(trimap.cuda()), Variable(bg.cuda())
        image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch = Variable(image_torch.cuda()), Variable(trimap_torch.cuda()), Variable(image_transformed_torch.cuda()), Variable(trimap_transformed_torch.cuda())
        tr0 = time.time()
        fba_pred = netM(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
        # split the 7-channel prediction into alpha, fg, bg; all are in [0, 1]
        alpha_pred = fba_pred[:, 0:1, :, :]
        fg_pred = fba_pred[:, 1:4, :, :]
        bg_pred = fba_pred[:, 4:7, :, :]
        # masks
        mask0 = Variable(torch.ones(alpha.shape).cuda())    # whole-image mask
        mask = (alpha > 0.01).type(torch.cuda.FloatTensor)  # quite loose fg mask from alpha
        mask1 = (seg > 0.95).type(torch.cuda.FloatTensor)   # fg mask from seg
        mask2 = (seg <= 0.95).type(torch.cuda.FloatTensor)  # bg mask from seg
        ###### matting loss ######
        ## l_fba = l1_alpha + l1_c + l1_lap + l1_g + 0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        l1_alpha = l1_loss(alpha, alpha_pred.float(), mask0)  # whole-image mask
        # the composite trivially matches the image in pure-background regions,
        # so use mask1, a somewhat tight fg mask, for the composition loss
        l1_c = c_loss(image_torch, alpha_pred, fg_pred, bg, mask1)
        # whole-image mask for the Laplacian loss
        l1_lap = lap_loss(alpha, alpha_pred, mask0)
        if args.use_gradient_loss == 1:
            l1_g = g_loss(alpha, alpha_pred, mask0)  # the gradient loss seemed to make results worse, so it is off by default
        if args.use_ab_exclusion_loss == 1:
            l1_excl_ab = excl_loss(alpha_pred.float()*fg.float(), bg.float(), level=3)
        # FB loss: whole-image mask for the fg term, loose bg mask for the bg term
        l1_fb = l1_loss(fg, fg_pred, mask0) + l1_loss(bg, bg_pred, mask2)
        # force the network to separate fg and bg
        l1_excl_fb = excl_loss(fg_pred, bg_pred, level=3)
        # the predicted composite is compared against the whole image, so use the whole-image mask
        l1_c_fb = c_loss(image_torch, alpha, fg_pred, bg_pred, mask0)
        # whole-image mask for the fg/bg Laplacian losses
        l1_lap_fb = lap_loss(fg, fg_pred, mask0) + lap_loss(bg, bg_pred, mask0)
        if args.use_gradient_loss == 1:
            if args.use_ab_exclusion_loss == 1:
                loss = l1_alpha + l1_c + l1_lap + l1_g + l1_excl_ab + 0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            else:
                loss = l1_alpha + l1_c + l1_lap + l1_g + 0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        else:
            if args.use_ab_exclusion_loss == 1:
                loss = l1_alpha + l1_c + l1_lap + l1_excl_ab + 0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
            else:
                loss = l1_alpha + l1_c + l1_lap + 0.25*(l1_fb + l1_excl_fb + l1_c_fb + l1_lap_fb)
        optimizerM.zero_grad()
        loss.backward()
        optimizerM.step()

        netL += loss.data
        lM += loss.data
        alL += l1_alpha.data
        lapL += l1_lap.data
        if args.use_gradient_loss == 1:
            gradL += l1_g.data
        compL += l1_c.data
        if args.use_ab_exclusion_loss == 1:
            abL += l1_excl_ab.data
        fbL += (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data)
        log_writer.add_scalar('Matting Loss', loss.data, epoch*KK + i + 1)
        log_writer.add_scalar('Matting Loss: Alpha', l1_alpha.data, epoch*KK + i + 1)
        log_writer.add_scalar('Matting Loss: Comp', l1_c.data, epoch*KK + i + 1)
        log_writer.add_scalar('Matting Loss: Lap', l1_lap.data, epoch*KK + i + 1)
        if args.use_ab_exclusion_loss == 1:
            log_writer.add_scalar('Matting Loss: Excl AB', l1_excl_ab.data, epoch*KK + i + 1)
        if args.use_gradient_loss == 1:
            log_writer.add_scalar('Matting Loss: Gradient', l1_g.data, epoch*KK + i + 1)
            log_writer.add_scalar('Matting Loss: FB', (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data + l1_g.data), epoch*KK + i + 1)
        else:
            log_writer.add_scalar('Matting Loss: FB', (l1_fb.data + l1_excl_fb.data + l1_c_fb.data + l1_lap_fb.data), epoch*KK + i + 1)
        t1 = time.time()
        elapse += t1 - t0
        elapse_run += t1 - tr0
        t0 = t1
        testL += loss.data
        testLM += loss.data
        ct_tst += 1
        if i % step == (step - 1):
            if args.use_gradient_loss == 1:
                if args.use_ab_exclusion_loss == 1:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Gradient-loss: %.4f Excl-AB-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, gradL/step, abL/step, fbL/step, elapse/step, elapse_run/step))
                else:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Gradient-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, gradL/step, fbL/step, elapse/step, elapse_run/step))
            else:
                if args.use_ab_exclusion_loss == 1:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f Excl-AB-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, abL/step, fbL/step, elapse/step, elapse_run/step))
                else:
                    print('[%d, %5d] Total-loss: %.4f Matting-loss: %.4f Alpha-loss: %.4f Comp-loss: %.4f Lap-loss: %.4f FB-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL/step, lM/step, alL/step, compL/step, lapL/step, fbL/step, elapse/step, elapse_run/step))
            # reset the running averages
            netL = 0
            lT, lM = 0, 0
            alL, lapL, gradL, compL, fbL = 0, 0, 0, 0, 0
            abL = 0
            elapse_run, elapse = 0, 0
            write_tb_log(image, 'image', log_writer, i)
            write_tb_log(seg, 'seg', log_writer, i)
            write_tb_log(alpha, 'alpha', log_writer, i)
            write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
            # mask the fg with the quite loose fg mask
            write_tb_log(fg*mask, 'fg', log_writer, i)
            write_tb_log(fg_pred, 'fg_pred', log_writer, i)
            write_tb_log(bg*mask2, 'bg', log_writer, i)
            write_tb_log(bg_pred, 'bg_pred', log_writer, i)
            write_tb_log(ori_trimap_, 'trimap_', log_writer, i)
            # composite from the predictions
            comp1 = fg_pred*alpha_pred + (1 - alpha_pred)*bg_pred
            write_tb_log(comp1, 'composite_pred', log_writer, i)
            del comp1
        # free per-iteration tensors to clear the memory
        del fg, alpha, image, seg, ori_trimap_, trimap, bg, image_torch, trimap_torch
        del image_transformed_torch, trimap_transformed_torch
        del alpha_pred, fg_pred, bg_pred, fba_pred
        # del l1_alpha, l1_c, l1_lap
        # del l1_fb, l1_excl_fb, l1_c_fb, l1_lap_fb
        # del l1_excl_ab
        del mask, mask0, mask1, mask2
        # if args.use_gradient_loss == 1:
        #     del l1_g
    schedulerM.step()
    netM.eval()
    # test_netM returns five metrics; the second return value is not used here
    mean_mse_gt, _, mean_sad_gt, mean_conn_gt, mean_gradient_gt = test_netM(test_loader, netM, img_dir_netT, img_dir_gt, epoch, args.use_full_image, log_writer=log_writer, timestamp=(epoch+1)*KK)
    log_writer.add_scalar('Test MSE GT Loss', mean_mse_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test SAD GT Loss', mean_sad_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test CONN GT Error', mean_conn_gt, (epoch+1)*KK)
    log_writer.add_scalar('Test Gradient GT Loss', mean_gradient_gt, (epoch+1)*KK)
    print("**************************Test Results*****************************")
    print('[Epoch %d] MSE-GT-loss: %.4f SAD-GT-loss: %.4f CONN-Error-GT: %.4f Gradient-GT-loss: %.4f' % (epoch + 1, mean_mse_gt, mean_sad_gt, mean_conn_gt, mean_gradient_gt))
    torch.save({
        'epoch': epoch,
        'model_state_dict': netM.state_dict(),
        'optimizer_state_dict': optimizerM.state_dict(),
        'loss': testL/ct_tst,
    }, model_dir + '/netM_models_epoch_%d_%.4f.pth' % (epoch, testL/ct_tst))
    test_time_end = time.time()
    print("Epoch {} Done".format(epoch), (test_time_end - test_time_start)/60.0, "minutes")