This article is my analysis of the BASNet source code. If you run into any problems, feel free to ask me on GitHub.
1. data_loader
from __future__ import print_function, division  # use Python 3 print/division semantics under Python 2
import glob  # file-name pattern matching
import math  # needed by CenterCrop below (math.floor)
import torch  # multi-dimensional tensor data structures and operations
from skimage import io, transform, color  # digital image processing
import numpy as np  # numerical arrays and matrix computation
import matplotlib.pyplot as plt  # plotting
from torch.utils.data import Dataset, DataLoader  # dataset loading
from torchvision import transforms, utils  # transforms on PIL.Image
from PIL import Image  # image handling
1.1 Resizing images
1.1.1 RescaleT: rescale to a square of the output size
class RescaleT(object):
def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        # output_size must be an int or a tuple, otherwise an AssertionError is raised
self.output_size = output_size
def __call__(self,sample):
image, label = sample['image'],sample['label']
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
        # resize to output_size x output_size; transform.resize also maps the image
        # from the [0,255] range to [0,1]. Note that the aspect-preserving new_h/new_w
        # computed above are not used here: the output is always square (see Rescale
        # below for the proportional version).
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
        img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
        lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
return {'image':img,'label':lbl}
1.1.2 Rescale: scale proportionally to the output size
class Rescale(object):  # scale the image proportionally so that its shorter side equals output_size
def __init__(self,output_size):
assert isinstance(output_size,(int,tuple))
self.output_size = output_size
def __call__(self,sample):
image, label = sample['image'],sample['label']
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
        # resize the image to new_h x new_w; transform.resize also maps the range [0,255] to [0,1]
img = transform.resize(image,(new_h,new_w),mode='constant')
lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
return {'image':img,'label':lbl}
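To see the difference between the two transforms, here is a minimal sketch (the dummy sample is mine) feeding the same data through both:

import numpy as np
# dummy 300x400 RGB image with a single-channel label
sample = {'image': np.random.rand(300, 400, 3),
          'label': np.random.rand(300, 400, 1)}
out_t = RescaleT(256)(sample)  # always square: image becomes (256, 256, 3)
out_r = Rescale(256)(sample)   # shorter side becomes 256: image becomes (256, 341, 3)
print(out_t['image'].shape, out_r['image'].shape)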
1.1.3 CenterCrop: crop the output size from the image center
class CenterCrop(object):
def __init__(self,output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self,sample):
image, label = sample['image'], sample['label']
h, w = image.shape[:2]
new_h, new_w = self.output_size
# print("h: %d, w: %d, new_h: %d, new_w: %d"%(h, w, new_h, new_w))
assert((h >= new_h) and (w >= new_w))
h_offset = int(math.floor((h - new_h)/2))
w_offset = int(math.floor((w - new_w)/2))
image = image[h_offset: h_offset + new_h, w_offset: w_offset + new_w] #从中心裁剪
label = label[h_offset: h_offset + new_h, w_offset: w_offset + new_w]
return {'image': image, 'label': label}
1.1.4 RandomCrop: crop the output size at a random position
class RandomCrop(object):
def __init__(self,output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self,sample):
image, label = sample['image'], sample['label']
h, w = image.shape[:2]
new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)  # note: requires h > new_h, otherwise randint raises ValueError
        left = np.random.randint(0, w - new_w)
image = image[top: top + new_h, left: left + new_w]
label = label[top: top + new_h, left: left + new_w]
return {'image': image, 'label': label}
1.2 Converting ndarrays to tensors
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, label = sample['image'], sample['label']
        tmpImg = np.zeros((image.shape[0],image.shape[1],3))
        # np.zeros returns an array of the given shape and type filled with zeros: H rows, W columns, 3 channels
        tmpLbl = np.zeros(label.shape)
        image = image/np.max(image)  # scale the image into the [0,1] range
        if(np.max(label)<1e-6):  # an (almost) all-zero label is left as is, to avoid dividing by zero
            label = label
        else:
            label = label/np.max(label)
        if image.shape[2]==1:  # grayscale: replicate the single channel into all three
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
        tmpLbl[:,:,0] = label[:,:,0]
        # normalize with the ImageNet statistics, i.e. the equivalent of
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225));
        # these are the per-channel means and standard deviations usually used for ImageNet-pretrained models
        tmpImg = tmpImg.transpose((2, 0, 1))  # HWC -> CHW, like a transpose over the axes
        tmpLbl = label.transpose((2, 0, 1))   # note: label is transposed directly, so the tmpLbl copy above goes unused
return {'image': torch.from_numpy(tmpImg),
'label': torch.from_numpy(tmpLbl)}
# torch.from_numpy converts the ndarray to a tensor that shares the same memory:
# modifying the tensor changes the ndarray and vice versa, and the returned tensor cannot be resized.
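The manual per-channel normalization above matches torchvision's Normalize with the ImageNet statistics; a quick sketch (my own dummy image) to verify the equivalence:

import numpy as np
import torch
from torchvision import transforms
img = np.random.rand(224, 224, 3).astype(np.float32)  # already scaled to [0,1]
# manual normalization, exactly as in ToTensor above
manual = (img - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
manual = torch.from_numpy(manual.transpose((2, 0, 1))).float()
# the torchvision equivalent, applied to the CHW tensor
norm = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
print(torch.allclose(manual, norm(torch.from_numpy(img.transpose((2, 0, 1)))), atol=1e-5))  # True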
class ToTensorLab(object):  # like ToTensor, but with a choice of color space
"""Convert ndarrays in sample to Tensors."""
def __init__(self,flag=0):
self.flag = flag
def __call__(self, sample):
image, label = sample['image'], sample['label']
tmpLbl = np.zeros(label.shape)
if(np.max(label)<1e-6):
label = label
else:
label = label/np.max(label)
        # change the color space: convert RGB to Lab, a color space that describes
        # colors numerically in terms of human visual perception. The L channel is the
        # lightness in [0,100] (pure black to pure white); a runs from red to green and
        # b from yellow to blue, both in the range [127,-128].
if self.flag == 2: # with rgb and Lab colors
tmpImg = np.zeros((image.shape[0],image.shape[1],6))
tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImgt[:,:,0] = image[:,:,0]
tmpImgt[:,:,1] = image[:,:,0]
tmpImgt[:,:,2] = image[:,:,0]
else:
tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)  # convert RGB triples to Lab triples
            # normalize each channel to the range [0,1]
            tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))  # channels 0-2: RGB
            tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
            tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
            tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))  # channels 3-5: Lab
            tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
            tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])  # standardize: subtract the mean, divide by the standard deviation
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
elif self.flag == 1: #with Lab color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImg[:,:,0] = image[:,:,0]
tmpImg[:,:,1] = image[:,:,0]
tmpImg[:,:,2] = image[:,:,0]
else:
tmpImg = image
tmpImg = color.rgb2lab(tmpImg)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
else: # with rgb color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
image = image/np.max(image)
if image.shape[2]==1:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
        tmpLbl[:,:,0] = label[:,:,0]
        # normalize with the ImageNet statistics, i.e. the equivalent of
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))  # HWC -> CHW
        tmpLbl = label.transpose((2, 0, 1))
return {'image': torch.from_numpy(tmpImg),
'label': torch.from_numpy(tmpLbl)}
1.3 Reading images into arrays
class SalObjDataset(Dataset):
def __init__(self,img_name_list,lbl_name_list,transform=None):
# self.root_dir = root_dir
# self.image_name_list = glob.glob(image_dir+'*.png')
# self.label_name_list = glob.glob(label_dir+'*.png')
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self,idx):
#image=Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
#label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
        image = io.imread(self.image_name_list[idx])
        # io.imread reads an image file and returns a uint8 numpy array
        if(0==len(self.label_name_list)):  # handle the case where no labels exist (e.g. at test time)
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])
#print("len of label3")
#print(len(label_3.shape))
#print(label_3.shape)
        label = np.zeros(label_3.shape[0:2])  # initialize the label array
if(3==len(label_3.shape)):
label = label_3[:,:,0]
elif(2==len(label_3.shape)):
label = label_3
        if(3==len(image.shape) and 2==len(label.shape)):  # make both image and label 3-dimensional (H, W, C)
label = label[:,:,np.newaxis]
elif(2==len(image.shape) and 2==len(label.shape)):
image = image[:,:,np.newaxis]
label = label[:,:,np.newaxis]
        # # vertical flipping (disabled)
        # # fliph = np.random.randn(1)
        # flipv = np.random.randn(1)
        # # np.random.randn draws from the standard normal distribution, so the flip
        # # below would fire with probability 0.5
        # if flipv>0:
        #     image = image[::-1,:,:]  # flip the image upside down
        #     label = label[::-1,:,:]
        # # vertical flip
sample = {'image':image, 'label':label}
if self.transform:
sample = self.transform(sample)
return sample
2. basnet_train
import torch
import torchvision  # torchvision bundles common datasets, model architectures and image transforms
from torch.autograd import Variable
# Variable wraps a tensor and records the operations applied to it, building the
# computation graph node by node so errors can be backpropagated
import torch.nn as nn  # building and training neural networks
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader  # dataset loading
from torchvision import transforms, utils  # transforms on PIL.Image
import torch.optim as optim  # implementations of various optimization algorithms
import torchvision.transforms as standard_transforms
import numpy as np
import glob
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import CenterCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import BASNet
import pytorch_ssim
import pytorch_iou
2.1 Loss functions
bce_loss = nn.BCELoss(size_average=True)
# binary cross-entropy between predictions and targets, averaged into a single loss value
# (size_average is deprecated in newer PyTorch; use reduction='mean' instead)
ssim_loss = pytorch_ssim.SSIM(window_size=11,size_average=True)
iou_loss = pytorch_iou.IOU(size_average=True)
def bce_ssim_loss(pred,target):
bce_out = bce_loss(pred,target)
ssim_out = 1 - ssim_loss(pred,target)
iou_out = iou_loss(pred,target)
loss = bce_out + ssim_out + iou_out
return loss
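A minimal sketch (dummy tensors of my own) to exercise the hybrid BCE + SSIM + IoU loss; for a perfect prediction all three terms go to 0:

# random probability map and binary ground truth, shaped (batch, channel, H, W)
pred = torch.rand(1, 1, 224, 224)
target = (torch.rand(1, 1, 224, 224) > 0.5).float()
print(bce_ssim_loss(pred, target))    # some positive value
print(bce_ssim_loss(target, target))  # near 0: BCE = 0, 1 - SSIM = 0, IoU loss = 0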
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v):
loss0 = bce_ssim_loss(d0,labels_v)
loss1 = bce_ssim_loss(d1,labels_v)
loss2 = bce_ssim_loss(d2,labels_v)
loss3 = bce_ssim_loss(d3,labels_v)
loss4 = bce_ssim_loss(d4,labels_v)
loss5 = bce_ssim_loss(d5,labels_v)
loss6 = bce_ssim_loss(d6,labels_v)
loss7 = bce_ssim_loss(d7,labels_v)
#ssim0 = 1 - ssim_loss(d0,labels_v)
# iou0 = iou_loss(d0,labels_v)
#loss = torch.pow(torch.mean(torch.abs(labels_v-d0)),2)*(5.0*loss0 + loss1 + loss2 + loss3 + loss4 + loss5) #+ 5.0*lossa
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7  #+ 5.0*lossa
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
    # note: indexing a loss with .data[0] is PyTorch <= 0.3 style; on 0.4+ use .item()
# print("BCE: l1:%3f, l2:%3f, l3:%3f, l4:%3f, l5:%3f, la:%3f, all:%3f\n"%(loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],lossa.data[0],loss.data[0]))
return loss0, loss
2.2 Setting up the training data directories
data_dir = './train_data/'
tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/'
tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/'
image_ext = '.jpg'
label_ext = '.png'
model_dir = "./saved_models/basnet_bsi/"
epoch_num = 100000  # number of passes over the training set
batch_size_train = 8  # number of samples per training step
batch_size_val = 1
train_num = 0
val_num = 0
tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)  # list of all paths matching './train_data/DUTS/DUTS-TR/DUTS-TR/im_aug/*.jpg'
tra_lbl_name_list = []
for img_path in tra_img_name_list:
    img_name = img_path.split("/")[-1]  # split on '/' and keep the last segment, i.e. the file name '*.jpg'
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]  # rebuild the stem, preserving any extra dots in the name
tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
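This loop simply swaps each image's extension for .png in the label directory; an equivalent, shorter sketch with os.path (the function name and example file name are mine):

import os
def label_path_for(img_path, label_dir=data_dir + tra_label_dir, label_ext='.png'):
    # drop the directory and extension, then look the stem up in the label directory
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return os.path.join(label_dir, stem + label_ext)
print(label_path_for(data_dir + tra_image_dir + 'sun_ekmqudbbrseiyiht.jpg'))  # hypothetical file name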
print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
salobj_dataset = SalObjDataset(
img_name_list=tra_img_name_list,
lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([  # Compose chains several image transforms, applied in order
        RescaleT(256),  # rescale the image to 256x256
        RandomCrop(224),  # randomly crop a 224x224 patch
        ToTensorLab(flag=0)]))  # convert the ndarrays to tensors with values in [0,1]; flag=0 selects the plain RGB space
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
# shuffle: whether to reshuffle the samples every epoch; num_workers: number of subprocesses used to load data
2.3 Defining the model
# define the net
net = BASNet(3, 1)
if torch.cuda.is_available():
net.cuda()
2.4 Optimizer
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# net.parameters() holds all of the network's trainable parameters once the model is built
# lr (float, optional): learning rate / step size (default 1e-3); it controls how strongly the
#   weights are updated. A larger value (e.g. 0.3) learns faster at the start, while a smaller
#   one (e.g. 1e-5) lets training converge to better final performance.
# betas (Tuple[float, float], optional): coefficients for the running averages of the gradient
#   and its square (default (0.9, 0.999)); beta1 is the exponential decay rate of the first-moment
#   estimate, beta2 of the second. With sparse gradients (as in NLP or vision tasks), beta2
#   should be set close to 1.
# eps (float, optional): term added to the denominator for numerical stability (default 1e-8)
# weight_decay (float, optional): weight decay, i.e. an L2 penalty (default 0)
2.5 Training loop
print("---start training...")
ite_num = 0  # total number of iterations since training started
running_loss = 0.0  # fused multi-level loss accumulated since the last checkpoint (every 2000 iterations)
running_tar_loss = 0.0  # final-output loss accumulated since the last checkpoint
ite_num4val = 0  # number of iterations since the last checkpoint
for epoch in range(0, epoch_num):  # train for 100k epochs
    net.train()
    # net.eval() must be called when not training: without it, layers such as dropout and
    # batch norm keep their stochastic training behaviour, so the outputs are not deterministic
    for i, data in enumerate(salobj_dataloader):  # one batch of 8 images per step
        # enumerate() pairs each element of an iterable (list, tuple, string, ...) with its
        # index; it is typically used in for loops
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
inputs, labels = data['image'], data['label']
        inputs = inputs.type(torch.FloatTensor)  # cast the tensor to float
labels = labels.type(torch.FloatTensor)
# wrap them in Variable
if torch.cuda.is_available():
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
requires_grad=False)
else:
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs_v)  # forward pass on the 4-D (N,C,H,W) batch; returns eight side outputs
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v)
        loss.backward()  # backpropagation: compute the current gradients
        optimizer.step()  # update the network parameters from the gradients
        # # print statistics
        running_loss += loss.data[0]  # accumulate the fused multi-level loss
        running_tar_loss += loss2.data[0]  # accumulate the loss of the final (refined) output
# del temporary outputs and loss
del d0, d1, d2, d3, d4, d5, d6, d7, loss2, loss
print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
        if ite_num % 2000 == 0:  # save the model every 2000 iterations
torch.save(net.state_dict(), model_dir + "basnet_bsi_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
running_loss = 0.0
running_tar_loss = 0.0
net.train() # resume train
ite_num4val = 0
print('-------------Congratulations! Training Done!!!-------------')
3. basnet_test
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from data_loader import RescaleT
from data_loader import CenterCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import BASNet
3.1 Normalization
def normPRED(d):  # min-max normalize the prediction to [0,1]
ma = torch.max(d)
mi = torch.min(d)
dn = (d-mi)/(ma-mi)
return dn
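A quick check of the min-max normalization on a random tensor (sketch):

d = torch.randn(1, 224, 224) * 5 + 3  # arbitrary value range
dn = normPRED(d)
print(dn.min(), dn.max())  # 0.0 and 1.0: the output is always squashed into [0,1]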
3.2 Saving the output
def save_output(image_name,pred,d_dir):  # save one prediction as an image
    predict = pred
    predict = predict.squeeze()  # drop all dimensions of size 1
    predict_np = predict.cpu().data.numpy()
    # .data reads the tensor inside the Variable, .cpu() moves it to the CPU, and .numpy() converts it to a numpy array
    im = Image.fromarray(predict_np*255).convert('RGB')  # array -> PIL image
    img_name = image_name.split("/")[-1]
    image = io.imread(image_name)  # read the original image as a numpy array to recover its size
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
    # bilinear interpolation computes each output pixel from the four surrounding source
    # pixels (a 2x2 neighbourhood), which largely removes aliasing and gives good quality
pb_np = np.array(imo)
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
imo.save(d_dir+imidx+'.png')
3.3 Getting the paths
# --------- 1. get image path and name ---------
image_dir = './test_data/test_images/'
prediction_dir = './test_data/test_results/'
model_dir = './saved_models/basnet_bsi/basnet.pth'
img_name_list = glob.glob(image_dir + '*.jpg')
3.4 Data loading
# --------- 2. dataloader ---------
#1. dataload
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
lbl_name_list = [],transform=transforms.Compose([RescaleT(256),ToTensorLab(flag=0)]))
# rescale to 256x256 and convert the ndarray sample to tensors
test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1,shuffle=False,num_workers=1)
3.5 Model definition
# --------- 3. model define ---------
print("...load BASNet...")
net = BASNet(3,1)
net.load_state_dict(torch.load(model_dir))  # load the trained model parameters
if torch.cuda.is_available():
net.cuda()
net.eval()
3.6 Inference on each image
# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):
print("inferencing:",img_name_list[i_test].split("/")[-1])
inputs_test = data_test['image']
inputs_test = inputs_test.type(torch.FloatTensor)
if torch.cuda.is_available():
inputs_test = Variable(inputs_test.cuda())
else:
inputs_test = Variable(inputs_test)
d1,d2,d3,d4,d5,d6,d7,d8 = net(inputs_test)
# normalization
    pred = d1[:,0,:,:]  # d1 is the refined final output; keep its single channel
pred = normPRED(pred)
# save results to test_results folder
save_output(img_name_list[i_test],pred,prediction_dir)
del d1,d2,d3,d4,d5,d6,d7,d8
4. resnet_model
## code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torchvision
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
# 'resnet152', 'ResNet34P','ResNet50S','ResNet50P','ResNet101P']
#
# resnet18_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet18-5c106cde.pth'
# resnet34_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet34-333f7ec4.pth'
# resnet50_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet50-19c8e357.pth'
# resnet101_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet101-5d3b4d8f.pth'
#
# model_urls = {
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# }
4.1 Creating a 3x3 convolution
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    # bias=False: when the convolution is followed by batch norm, a bias has no effect
    # (BN re-centers the activations anyway) and only wastes GPU memory
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
4.2 Defining the basic block
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)  # per-channel batch normalization
        self.relu = nn.ReLU(inplace=True)  # inplace=True overwrites the input tensor directly, saving the time and memory of an extra allocation
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
        residual = x  # keep the block's input for the shortcut (residual) connection
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class BasicBlockDe(nn.Module):  # like BasicBlock, but the shortcut branch gets its own conv-BN-ReLU
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlockDe, self).__init__()
self.convRes = conv3x3(inplanes,planes,stride)
self.bnRes = nn.BatchNorm2d(planes)
self.reluRes = nn.ReLU(inplace=True)
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = self.convRes(x)
residual = self.bnRes(residual)
residual = self.reluRes(residual)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
4.3 Bottleneck block
class Bottleneck(nn.Module):  # 1x1 -> 3x3 -> 1x1 bottleneck residual block
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
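A shape check for the residual blocks (a sketch): with stride 1 and an identity shortcut, BasicBlock preserves the spatial size, while Bottleneck multiplies the channels by expansion = 4 and so needs a downsample layer on its shortcut when the channel counts differ:

import torch
block = BasicBlock(64, 64)  # inplanes == planes, stride 1, identity shortcut
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])
down = nn.Sequential(nn.Conv2d(64, 256, kernel_size=1, bias=False), nn.BatchNorm2d(256))
bott = Bottleneck(64, 64, downsample=down)  # output channels = planes * expansion = 256
print(bott(x).shape)  # torch.Size([1, 256, 56, 56])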
5. BASNet
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F  # must be imported as F: the forward pass below calls F.sigmoid
from .resnet_model import *
5.1 Refinement module
class RefUnet(nn.Module):#优化模块
def __init__(self,in_ch,inc_ch):
super(RefUnet, self).__init__()
        self.conv0 = nn.Conv2d(in_ch,inc_ch,3,padding=1)
        # nn.Conv2d arguments:
        #   in_channels: number of input channels (required)
        #   out_channels: number of output channels (required)
        #   kernel_size: int or tuple; one int for a square kernel, a (h, w) tuple otherwise (required)
        #   stride: step of the sliding window, default 1 (optional)
        #   padding: zero border added on every side; with padding=1 a 3x3 input becomes 5x5 (optional)
        #   dilation: spacing between kernel elements (optional)
        self.conv1 = nn.Conv2d(inc_ch,64,3,padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2,2,ceil_mode=True)
        # nn.MaxPool2d arguments:
        #   kernel_size (int or tuple): size of the pooling window
        #   stride (int or tuple, optional): step of the window, defaults to kernel_size
        #   padding (int or tuple, optional): zero layers added to every border of the input
        #   dilation (int or tuple, optional): stride between elements inside the window
        #   return_indices: if True, also return the indices of the maxima (helps with unpooling)
        #   ceil_mode: if True, use ceil instead of the default floor when computing the output size
self.conv2 = nn.Conv2d(64,64,3,padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(2,2,ceil_mode=True)
self.conv3 = nn.Conv2d(64,64,3,padding=1)
self.bn3 = nn.BatchNorm2d(64)
self.relu3 = nn.ReLU(inplace=True)
self.pool3 = nn.MaxPool2d(2,2,ceil_mode=True)
self.conv4 = nn.Conv2d(64,64,3,padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.relu4 = nn.ReLU(inplace=True)
self.pool4 = nn.MaxPool2d(2,2,ceil_mode=True)
#####
self.conv5 = nn.Conv2d(64,64,3,padding=1)
self.bn5 = nn.BatchNorm2d(64)
self.relu5 = nn.ReLU(inplace=True)
#####
self.conv_d4 = nn.Conv2d(128,64,3,padding=1)
self.bn_d4 = nn.BatchNorm2d(64)
self.relu_d4 = nn.ReLU(inplace=True)
self.conv_d3 = nn.Conv2d(128,64,3,padding=1)
self.bn_d3 = nn.BatchNorm2d(64)
self.relu_d3 = nn.ReLU(inplace=True)
self.conv_d2 = nn.Conv2d(128,64,3,padding=1)
self.bn_d2 = nn.BatchNorm2d(64)
self.relu_d2 = nn.ReLU(inplace=True)
self.conv_d1 = nn.Conv2d(128,64,3,padding=1)
self.bn_d1 = nn.BatchNorm2d(64)
self.relu_d1 = nn.ReLU(inplace=True)
self.conv_d0 = nn.Conv2d(64,1,3,padding=1)
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
        # scale_factor: the output is this many times larger than the input
    def forward(self,x):
        hx = x
        hx = self.conv0(hx)
        hx1 = self.relu1(self.bn1(self.conv1(hx)))  # 224*224*64
        hx = self.pool1(hx1)
        hx2 = self.relu2(self.bn2(self.conv2(hx)))  # 112*112*64
        hx = self.pool2(hx2)
        hx3 = self.relu3(self.bn3(self.conv3(hx)))  # 56*56*64
        hx = self.pool3(hx3)
        hx4 = self.relu4(self.bn4(self.conv4(hx)))  # 28*28*64
        hx = self.pool4(hx4)
        hx5 = self.relu5(self.bn5(self.conv5(hx)))  # 14*14*64, plays the same role as the bridge in the prediction module
        hx = self.upscore2(hx5)
        # torch.cat(..., 1) concatenates along dimension 1, i.e. the channel dimension
        d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx,hx4),1))))  # 28*28*64
        hx = self.upscore2(d4)
        d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx,hx3),1))))  # 56*56*64
        hx = self.upscore2(d3)
        d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx,hx2),1))))  # 112*112*64
        hx = self.upscore2(d2)
        d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx,hx1),1))))  # 224*224*64
        residual = self.conv_d0(d1)
        return x + residual  # refined map = coarse input + learned residual
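Because the module adds a learned residual to its input, the output shape equals the input shape; a quick sketch:

import torch
ref = RefUnet(1, 64)  # 1-channel input, 64 internal channels
coarse = torch.randn(1, 1, 224, 224)  # a coarse saliency map
print(ref(coarse).shape)  # torch.Size([1, 1, 224, 224])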
5.2 Prediction module
class BASNet(nn.Module):
def __init__(self,n_channels,n_classes):
super(BASNet,self).__init__()
resnet = models.resnet34(pretrained=True)
## -------------Encoder--------------
self.inconv = nn.Conv2d(n_channels,64,3,padding=1)
self.inbn = nn.BatchNorm2d(64)
self.inrelu = nn.ReLU(inplace=True)
#stage 1
self.encoder1 = resnet.layer1 #224*224*64
#stage 2
self.encoder2 = resnet.layer2 #112*112*128
#stage 3
self.encoder3 = resnet.layer3 #56*56*256
#stage 4
self.encoder4 = resnet.layer4 #28*28*512
self.pool4 = nn.MaxPool2d(2,2,ceil_mode=True)
#stage 5
self.resb5_1 = BasicBlock(512,512)
self.resb5_2 = BasicBlock(512,512)
self.resb5_3 = BasicBlock(512,512) #14*14*512
self.pool5 = nn.MaxPool2d(2,2,ceil_mode=True)
#stage 6
self.resb6_1 = BasicBlock(512,512)
self.resb6_2 = BasicBlock(512,512)
self.resb6_3 = BasicBlock(512,512) #7*7*512
## -------------Bridge--------------
#stage Bridge
self.convbg_1 = nn.Conv2d(512,512,3,dilation=2, padding=2) #7*7*512
self.bnbg_1 = nn.BatchNorm2d(512)
self.relubg_1 = nn.ReLU(inplace=True)
self.convbg_m = nn.Conv2d(512,512,3,dilation=2, padding=2)
self.bnbg_m = nn.BatchNorm2d(512)
self.relubg_m = nn.ReLU(inplace=True)
self.convbg_2 = nn.Conv2d(512,512,3,dilation=2, padding=2)
self.bnbg_2 = nn.BatchNorm2d(512)
self.relubg_2 = nn.ReLU(inplace=True)
## -------------Decoder--------------
#stage 6d
self.conv6d_1 = nn.Conv2d(1024,512,3,padding=1)
self.bn6d_1 = nn.BatchNorm2d(512)
self.relu6d_1 = nn.ReLU(inplace=True)
self.conv6d_m = nn.Conv2d(512,512,3,dilation=2, padding=2)
self.bn6d_m = nn.BatchNorm2d(512)
self.relu6d_m = nn.ReLU(inplace=True)
self.conv6d_2 = nn.Conv2d(512,512,3,dilation=2, padding=2)
self.bn6d_2 = nn.BatchNorm2d(512)
self.relu6d_2 = nn.ReLU(inplace=True)
#stage 5d
self.conv5d_1 = nn.Conv2d(1024,512,3,padding=1)
self.bn5d_1 = nn.BatchNorm2d(512)
self.relu5d_1 = nn.ReLU(inplace=True)
self.conv5d_m = nn.Conv2d(512,512,3,padding=1)
self.bn5d_m = nn.BatchNorm2d(512)
self.relu5d_m = nn.ReLU(inplace=True)
self.conv5d_2 = nn.Conv2d(512,512,3,padding=1)
self.bn5d_2 = nn.BatchNorm2d(512)
self.relu5d_2 = nn.ReLU(inplace=True)
#stage 4d
self.conv4d_1 = nn.Conv2d(1024,512,3,padding=1)
self.bn4d_1 = nn.BatchNorm2d(512)
self.relu4d_1 = nn.ReLU(inplace=True)
self.conv4d_m = nn.Conv2d(512,512,3,padding=1)
self.bn4d_m = nn.BatchNorm2d(512)
self.relu4d_m = nn.ReLU(inplace=True)
self.conv4d_2 = nn.Conv2d(512,256,3,padding=1)
self.bn4d_2 = nn.BatchNorm2d(256)
self.relu4d_2 = nn.ReLU(inplace=True)
#stage 3d
self.conv3d_1 = nn.Conv2d(512,256,3,padding=1)
self.bn3d_1 = nn.BatchNorm2d(256)
self.relu3d_1 = nn.ReLU(inplace=True)
self.conv3d_m = nn.Conv2d(256,256,3,padding=1)
self.bn3d_m = nn.BatchNorm2d(256)
self.relu3d_m = nn.ReLU(inplace=True)
self.conv3d_2 = nn.Conv2d(256,128,3,padding=1)
self.bn3d_2 = nn.BatchNorm2d(128)
self.relu3d_2 = nn.ReLU(inplace=True)
#stage 2d
self.conv2d_1 = nn.Conv2d(256,128,3,padding=1)
self.bn2d_1 = nn.BatchNorm2d(128)
self.relu2d_1 = nn.ReLU(inplace=True)
self.conv2d_m = nn.Conv2d(128,128,3,padding=1)
self.bn2d_m = nn.BatchNorm2d(128)
self.relu2d_m = nn.ReLU(inplace=True)
self.conv2d_2 = nn.Conv2d(128,64,3,padding=1)
self.bn2d_2 = nn.BatchNorm2d(64)
self.relu2d_2 = nn.ReLU(inplace=True)
#stage 1d
self.conv1d_1 = nn.Conv2d(128,64,3,padding=1)
self.bn1d_1 = nn.BatchNorm2d(64)
self.relu1d_1 = nn.ReLU(inplace=True)
self.conv1d_m = nn.Conv2d(64,64,3,padding=1)
self.bn1d_m = nn.BatchNorm2d(64)
self.relu1d_m = nn.ReLU(inplace=True)
self.conv1d_2 = nn.Conv2d(64,64,3,padding=1)
self.bn1d_2 = nn.BatchNorm2d(64)
self.relu1d_2 = nn.ReLU(inplace=True)
## -------------Bilinear Upsampling--------------
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
## -------------Side Output--------------
self.outconvb = nn.Conv2d(512,1,3,padding=1)
self.outconv6 = nn.Conv2d(512,1,3,padding=1)
self.outconv5 = nn.Conv2d(512,1,3,padding=1)
self.outconv4 = nn.Conv2d(256,1,3,padding=1)
self.outconv3 = nn.Conv2d(128,1,3,padding=1)
self.outconv2 = nn.Conv2d(64,1,3,padding=1)
self.outconv1 = nn.Conv2d(64,1,3,padding=1)
## -------------Refine Module-------------
self.refunet = RefUnet(1,64)
def forward(self,x):
hx = x
## -------------Encoder-------------
        hx = self.inconv(hx)  # input convolution layer
hx = self.inbn(hx)
hx = self.inrelu(hx)
h1 = self.encoder1(hx) #224*224*64
h2 = self.encoder2(h1) #112*112*128
h3 = self.encoder3(h2) #56*56*256
h4 = self.encoder4(h3) #28*28*512
hx = self.pool4(h4)
hx = self.resb5_1(hx) #14*14*512
hx = self.resb5_2(hx)
h5 = self.resb5_3(hx)
hx = self.pool5(h5)
hx = self.resb6_1(hx) #7*7*512
hx = self.resb6_2(hx)
h6 = self.resb6_3(hx)
## -------------Bridge-------------
hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6))) #7*7*512
hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
## -------------Decoder-------------
hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg,h6),1))))#7*7*512
hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
hx = self.upscore2(hd6) # 7 -> 14
hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx,h5),1))))#14*14*512
hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
hx = self.upscore2(hd5) # 14 -> 28
hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx,h4),1))))#28*28*512
hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
hx = self.upscore2(hd4) # 28 -> 56
hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx,h3),1))))#56*56*256
hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
hx = self.upscore2(hd3) # 56 -> 112
hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx,h2),1))))#112*112*128
hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
hx = self.upscore2(hd2) # 112 -> 224
hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx,h1),1))))#224*224*64
hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
        ## -------------Side Outputs: the last layer of each of the 7 stages, upsampled to the input size-------------
db = self.outconvb(hbg)
db = self.upscore6(db) # 7->224
d6 = self.outconv6(hd6)
d6 = self.upscore6(d6) # 7->224
d5 = self.outconv5(hd5)
d5 = self.upscore5(d5) # 14->224
d4 = self.outconv4(hd4)
d4 = self.upscore4(d4) # 28->224
d3 = self.outconv3(hd3)
d3 = self.upscore3(d3) # 56->224
d2 = self.outconv2(hd2)
d2 = self.upscore2(d2) # 112->224
d1 = self.outconv1(hd1) # 224
        ## -------------Refine Module: the final map is refined by RefUnet-------------
dout = self.refunet(d1) # 224
return F.sigmoid(dout), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6), F.sigmoid(db)
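A sketch of a full forward pass (this downloads the pretrained ResNet-34 weights on first use, and assumes PyTorch 0.4+ for torch.no_grad); all eight outputs are sigmoid probability maps at the input resolution:

import torch
net = BASNet(3, 1)
net.eval()  # fix batch norm statistics for a deterministic forward pass
with torch.no_grad():
    outs = net(torch.randn(1, 3, 224, 224))
print([tuple(o.shape) for o in outs])  # eight maps of shape (1, 1, 224, 224)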
6. IOU
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
def _iou(pred, target, size_average = True):
    b = pred.shape[0]  # batch size
    IoU = 0.0
    for i in range(0,b):
        # compute the IoU of the foreground
        Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:])  # intersection; torch.sum adds up a tensor's entries, optionally along one dimension:
        # for a = tensor([[1, 1, 1], [1, 1, 1]]),
        # torch.sum(a) -> 6, torch.sum(a, dim=0) -> [2, 2, 2], torch.sum(a, dim=1) -> [3, 3]
        Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1  # union = |A| + |B| - intersection
        IoU1 = Iand1/Ior1
        # IoU loss is (1-IoU1)
        IoU = IoU + (1-IoU1)
return IoU/b
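A worked check of the loss (sketch): a perfect prediction gives 0, and a prediction covering exactly half of the target gives 0.5:

import torch
target = torch.ones(1, 1, 4, 4)
print(_iou(target, target))  # tensor(0.): intersection = union, IoU = 1, loss = 0
half = torch.cat([torch.ones(1, 1, 2, 4), torch.zeros(1, 1, 2, 4)], dim=2)
print(_iou(half, target))    # tensor(0.5000): IoU = 8 / 16 = 0.5, loss = 0.5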
class IOU(torch.nn.Module):
def __init__(self, size_average = True):
super(IOU, self).__init__()
self.size_average = size_average
def forward(self, pred, target):
return _iou(pred, target, self.size_average)
7. SSIM
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):  # build a 1-D Gaussian kernel
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()  # normalize so the weights sum to 1
def create_window(window_size, channel):  # build the 2-D window from the 1-D Gaussian
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # add a dimension: (N,) -> (N, 1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)  # outer product -> (1, 1, N, N)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    # expand to (channel, 1, N, N): the first dimension is the number of filters (output
    # channels, equal to channel, so the channel count is preserved after the convolution);
    # the second is in_channels/groups, which is 1 here because the conv uses groups=channel
    return window
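Since the 1-D kernel is normalized before the outer product, every 2-D window also sums to 1; a quick sketch:

w = create_window(11, 3)
print(w.size())    # torch.Size([3, 1, 11, 11])
print(w[0].sum())  # tensor(1.) up to float rounding, so the windowed means are proper weighted averages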
def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)  # local means via Gaussian-window convolution
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2
    # variances and covariance via Var(X) = E[X^2] - E[X]^2 and Cov(X,Y) = E[XY] - E[X]E[Y]
    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq  # local variance of img1
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq  # local variance of img2
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2  # local covariance
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size = 11, size_average = True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
        self.window = create_window(window_size, self.channel)  # build an initial window with the default channel count
    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()  # channel count of img1
        if channel == self.channel and self.window.data.type() == img1.data.type():
            # the cached window already matches img1's channel count and dtype, so reuse it
            window = self.window
        else:
            window = create_window(self.window_size, channel)  # rebuild the window for img1's channel count
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)  # convert the window to img1's dtype
            self.window = window    # cache the window built for img1
            self.channel = channel  # and remember its channel count
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def _logssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map))
ssim_map = -torch.log(ssim_map + 1e-8)
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class LOGSSIM(torch.nn.Module):
def __init__(self, window_size = 11, size_average = True):
super(LOGSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _logssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
print(create_window(11,1))  # sanity check: print the default 11x11 Gaussian window
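And a final sanity check of the SSIM itself (sketch): identical images score exactly 1, unrelated noise much lower:

img = torch.rand(1, 3, 64, 64)
print(ssim(img, img))                       # tensor(1.)
print(ssim(img, torch.rand(1, 3, 64, 64)))  # close to 0 for independent noise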
8. Test
import os
import cv2
import PIL.Image
# path_data1 holds the predicted masks, path_data2 the original images (both defined elsewhere)
img_path = path_data1 + os.listdir(path_data1)[0]
im1 = PIL.Image.open(path_data2 + "1.jpg")  # the original image
im1 = im1.convert("RGBA")
tmp = cv2.imread(img_path)
ret, tmp = cv2.threshold(tmp, 127, 255, cv2.THRESH_BINARY)  # binarize the predicted mask
cv2.imwrite(img_path, tmp)
im2 = PIL.Image.open(img_path)  # the binarized black-and-white mask
im2 = im2.convert("RGBA")
width, height = im1.size  # note: PIL's size is (width, height)
for x in range(0, width):
    for y in range(0, height):
        r, g, b = im2.getpixel((x, y))[0:3]
        if (r, g, b) == (0, 0, 0):  # background in the mask: make the original pixel transparent
            im1.putpixel((x, y), (255, 255, 255, 0))
# im1 now holds the cut-out foreground; save it with e.g. im1.save('result.png')