-
Notifications
You must be signed in to change notification settings - Fork 211
/
Copy pathloss.py
executable file
·315 lines (261 loc) · 10.1 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu ([email protected])
###############################################################################
"""VGG Losses"""
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models
def gram_matrix(input_tensor):
    """
    Compute the normalized Gram matrix of a batch of feature maps.

    :param input_tensor: input tensor with shape
        (batch_size, nbr_channels, height, width)
    :return: Gram matrix, tensor with shape
        (batch_size, nbr_channels, nbr_channels), scaled by
        1 / (nbr_channels * height * width); same dtype/device as the input
    """
    (b, ch, h, w) = input_tensor.size()
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
    # are accepted as well; it only copies when a view is impossible
    features = input_tensor.reshape(b, ch, h * w)
    features_t = features.transpose(1, 2)
    # baddbmm applies the 1/(ch*h*w) scaling as part of the batched matmul,
    # the more efficient and formal way to avoid underflow for mixed
    # precision training (the naive alternative is to pre-divide the
    # features by (ch * h) and divide the bmm result by w).
    # With beta=0 the accumulator's contents are ignored, so it is allocated
    # directly with the features' dtype/device instead of on the CPU.
    acc = features.new_zeros(b, ch, ch)
    gram = torch.baddbmm(acc, features, features_t,
                         beta=0, alpha=1. / (ch * h * w))
    return gram
class PerceptualLoss(nn.Module):
    """
    Perceptual Loss Module

    Base loss: L1 distance between ImageNet-normalized inputs.
    """

    def __init__(self):
        """Init"""
        super().__init__()
        self.l1_loss = torch.nn.L1Loss()
        # not used by forward() below
        self.mse_loss = torch.nn.MSELoss()

    @staticmethod
    def normalize_batch(batch, div_factor=255.):
        """
        Normalize a batch of RGB images with the ImageNet mean and std.

        :param batch: input tensor with shape
            (batch_size, 3, height, width); integer dtypes are accepted
        :param div_factor: the batch is divided by this factor before
            whitening, e.g. 255. to map [0, 255] pixel values to [0, 1]
        :return: normalized data, floating-point tensor with the same shape
            as batch
        :raises ValueError: if batch is not 4-D with exactly 3 channels
        """
        if batch.dim() != 4 or batch.size(1) != 3:
            raise ValueError(
                'expected a (batch_size, 3, height, width) tensor, '
                f'got shape {tuple(batch.shape)}')
        # divide first so integer images become floating point before the
        # constants are created with the batch's dtype/device
        batch = torch.div(batch, div_factor)
        # (1, 3, 1, 1) constants broadcast over batch and spatial dims, so
        # every channel gets a defined mean/std and no full-size buffers
        # are allocated
        mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (batch - mean) / std

    def forward(self, x, y):
        """
        Forward
        :param x: input tensor with shape
            (batch_size, 3, height, width)
        :param y: input tensor with shape
            (batch_size, 3, height, width)
        :return: l1 loss between the normalized data
        """
        x = self.normalize_batch(x)
        y = self.normalize_batch(y)
        return self.l1_loss(x, y)
def make_vgg16_layers(style_avg_pool=False):
    """
    Build the convolutional ("features") part of VGG16.

    Module indices follow the VGG16 configuration below, so the result can
    take the checkpoint that VGG16Partial loads into it. Pooling layers are
    nn.MaxPool2d unless style_avg_pool is set.

    :param style_avg_pool: use nn.AvgPool2d instead of nn.MaxPool2d for
        every pooling layer
    :return: nn.Sequential of Conv2d(3x3, padding 1) + ReLU(inplace) pairs
        separated by 2x2 / stride-2 pooling layers
    """
    pool = nn.AvgPool2d if style_avg_pool else nn.MaxPool2d
    # numbers are conv output channels, 'M' marks a pooling layer
    plan = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M')
    modules = []
    channels_in = 3
    for item in plan:
        if item == 'M':
            modules.append(pool(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels_in, item, kernel_size=3, padding=1))
        modules.append(nn.ReLU(inplace=True))
        channels_in = item
    return nn.Sequential(*modules)
class VGG16Partial(nn.Module):
    """
    VGG16 partial model

    Splits the convolutional part of a pretrained VGG16 into up to four
    consecutive slices and returns the output of each slice.
    """

    def __init__(self, vgg_path='~/.torch/vgg16-397923af.pth', layer_num=3):
        """
        Init
        :param vgg_path: path to a VGG16 state dict; a leading '~' is
            expanded to the user's home directory
        :param layer_num: number of feature maps (slices) returned by
            forward(); values above 4 behave like 4
        :raises ValueError: if layer_num is not a positive int
        """
        super().__init__()
        # validate before the checkpoint load; an explicit raise is not
        # stripped under `python -O` the way an assert is
        if not isinstance(layer_num, int) or layer_num <= 0:
            raise ValueError(
                f'layer_num must be a positive int, got {layer_num!r}')
        vgg_model = models.vgg16()
        vgg_model.features = make_vgg16_layers()
        # torch.load does not expand '~' itself, so without expanduser the
        # default path would never resolve.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        vgg_model.load_state_dict(
            torch.load(os.path.expanduser(vgg_path), map_location='cpu')
        )
        vgg_pretrained_features = vgg_model.features
        self.layer_num = layer_num
        # [start, end) indices into vgg_pretrained_features for slice1..4;
        # with make_vgg16_layers() each slice ends right after a pooling layer
        bounds = ((0, 5), (5, 10), (10, 17), (17, 24))
        for slice_idx, (start, end) in enumerate(bounds[:layer_num], start=1):
            block = torch.nn.Sequential()
            for x in range(start, end):
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f'slice{slice_idx}', block)
        # the network is a frozen feature extractor
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def normalize_batch(batch, div_factor=1.0):
        """
        Normalize a batch of RGB images with the ImageNet mean and std.

        :param batch: input tensor with shape
            (batch_size, 3, height, width); integer dtypes are accepted
        :param div_factor: the batch is divided by this factor before
            whitening
        :return: normalized data, floating-point tensor with the same shape
            as batch
        :raises ValueError: if batch is not 4-D with exactly 3 channels
        """
        if batch.dim() != 4 or batch.size(1) != 3:
            raise ValueError(
                'expected a (batch_size, 3, height, width) tensor, '
                f'got shape {tuple(batch.shape)}')
        # divide first so integer images become floating point before the
        # constants are created with the batch's dtype/device
        batch = torch.div(batch, div_factor)
        # (1, 3, 1, 1) constants broadcast over batch and spatial dims, so
        # every channel gets a defined mean/std
        mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (batch - mean) / std

    def forward(self, x):
        """
        Forward, get features used for perceptual loss
        :param x: input tensor with shape
            (batch_size, nbr_channels, height, width)
        :return: list of min(self.layer_num, 4) feature maps, the output of
            slice1, slice2, ... in order (each slice feeds the next)
        """
        h = x
        output = []
        for slice_idx in range(1, min(self.layer_num, 4) + 1):
            h = getattr(self, f'slice{slice_idx}')(h)
            output.append(h)
        return output
# perceptual loss and (spatial) style loss
class VGG16PartialLoss(PerceptualLoss):
    """
    VGG16 perceptual loss

    Weighted sum of an L1 term on the normalized images, a VGG16 feature
    (content) term, a Gram-matrix style term and a smoothness term on
    neighbouring pixel differences of the normalized output.
    """

    def __init__(self, l1_alpha=5.0, perceptual_alpha=0.05, style_alpha=120,
                 smooth_alpha=0, feat_num=3,
                 vgg_path='~/.torch/vgg16-397923af.pth'):
        """
        Init
        :param l1_alpha: weight of the l1 loss
        :param perceptual_alpha: weight of the perceptual loss
        :param style_alpha: weight of the style loss
        :param smooth_alpha: weight of the regularizer
        :param feat_num: number of (deepest) feature maps used by the
            perceptual and style terms; at most the maps returned by the
            internal VGG16Partial (3 with its default layer_num) are used
        :param vgg_path: path to the pretrained VGG16 state dict
        """
        super().__init__()
        self.vgg16partial = VGG16Partial(vgg_path=vgg_path).eval()
        # reduction='mean' is the non-deprecated spelling of size_average=True
        self.loss_fn = torch.nn.L1Loss(reduction='mean')
        self.l1_weight = l1_alpha
        self.vgg_weight = perceptual_alpha
        self.style_weight = style_alpha
        self.regularize_weight = smooth_alpha
        self.dividor = 1
        self.feat_num = feat_num

    def forward(self, output0, target0):
        """
        Forward
        assuming both output0 and target0 are in the range of [0, 1]
        :param output0: output of a model, tensor with shape
            (batch_size, nbr_channels, height, width)
        :param target0: target, tensor with shape
            (batch_size, nbr_channels, height, width)
        :return: tuple (total loss, perceptual loss, style loss); the last
            two are the int 0 when their term is disabled
        """
        y = self.normalize_batch(target0, self.dividor)
        x = self.normalize_batch(output0, self.dividor)
        # L1 loss
        l1_loss = self.l1_weight * (torch.abs(x - y).mean())
        vgg_loss = 0
        style_loss = 0
        smooth_loss = 0
        # VGG
        if self.vgg_weight != 0 or self.style_weight != 0:
            # target features are constants: no autograd graph for them
            with torch.no_grad():
                groundtruth = self.vgg16partial(y.detach())
            generated = self.vgg16partial(x)
            # use the last feat_num maps; clamping at 0 prevents a feat_num
            # larger than the number of maps from producing negative
            # (wrapped-around) indices that would count maps twice
            first = max(0, len(generated) - self.feat_num)
            # vgg loss: VGG content loss
            if self.vgg_weight > 0:
                for m in range(first, len(generated)):
                    vgg_loss += (
                        self.vgg_weight
                        * self.loss_fn(generated[m], groundtruth[m])
                    )
            # style loss: Gram matrix loss
            if self.style_weight > 0:
                for m in range(first, len(generated)):
                    gt_style = gram_matrix(groundtruth[m])
                    gen_style = gram_matrix(generated[m])
                    style_loss += (
                        self.style_weight * self.loss_fn(gen_style, gt_style)
                    )
        # smooth term: mean absolute horizontal + vertical neighbour diffs
        if self.regularize_weight != 0:
            smooth_loss += self.regularize_weight * (
                torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]).mean() +
                torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]).mean()
            )
        tot = l1_loss + vgg_loss + style_loss + smooth_loss
        return tot, vgg_loss, style_loss