您的位置:首页 > 理论基础

[pytorch] 计算图像的一阶导/梯度/gradient

2018-01-30 18:43 197 查看

[pytorch] 计算图像的一阶导/梯度/gradient

在图像转换任务中常见的total variation loss(tvloss,总变分,一般作为平滑的规则化项)需要对图像的梯度求平方和。

style-transfer系的github项目,tvloss求法如下:

class TVLoss(torch.nn.Module):
def __init__(self):
super(TVLoss,self).__init__()

def forward(self,x):
h_x = x.size()[2]
w_x = x.size()[3]
count_h = self._tensor_size(x[:,:,1:,:])
count_w = self._tensor_size(x[:,:,:,1:])
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
return h_tv/count_h + w_tv/count_w

def _tensor_size(self,t):
return t.size()[1]*t.size()[2]*t.size()[3]


根据上面可以看出,这个求法避免了 (x[:,:,1:,:]-x[:,:,:h_x-1,:])与(x[:,:,:,1:]-x[:,:,:,:w_x-1])维度不一样的问题。

那么如果想要求得梯度作为变量怎么办咧?

我参考了numpy源码的写法,尝试利用torch.nn.functional.pad()这个函数实现了一个一阶导的版本,其中dx取值为2个像素:

import torch
import torch.nn.functional as F
def gradient_1order(x,h_x=None,w_x=None):
if h_x is None and w_x is None:
h_x = x.size()[2]
w_x = x.size()[3]
r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:]
l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x]
t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :]
b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :]
xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
return xgrad


测试

import torch
npim = np.array(im,dtype=np.float32)
import numpy as np
import PIL.Image as Image

def gradient_numpy(img):
# input : Image/nparray 0~255
changeflag = isinstance(img, Image.Image)
if changeflag:
img = np.array(img)
dx,dy = np.gradient(img/255,edge_order=1)
img = (np.sqrt(dx**2 + dy**2)*255).astype(np.uint8)
if changeflag:
img = Image.fromarray(img)
return img

im = Image.open('text.jpg').convert('L').crop([1,1,200,200])
# fig 1
im.show()
# fig 2
gradient_numpy(im).show()

npim = np.array(im,dtype=np.float32)
tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
gradient_1order(tim)
npgrad = t.squeeze(0).squeeze(0).data.clamp(0,255).numpy()
# fig 3
Image.fromarray(npgrad.astype('uint8')).show()








从效果上看是没问题的,但是遇到个我不解的情况:图像pair之间如果用l2loss或l1loss都可以用作生成网络的loss,但为什么我用图像的梯度pair做l2loss或l1loss就不可以呢?反向传播的时候生成网络直接梯度爆炸了。。。

并不是只用了这个梯度l2loss,而是不管这个loss权重多小,学习率多低,都会爆炸,不是很明白。。。。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息