如果你是攝影愛好者,你可能對濾鏡并不陌生。它可以改變照片的色彩風(fēng)格,使風(fēng)景照片變得更清晰或肖像照片皮膚變白。但是,一個濾鏡通常只會改變照片的一個方面。要為照片應(yīng)用理想的風(fēng)格,您可能需要嘗試多種不同的濾鏡組合。這個過程與調(diào)整模型的超參數(shù)一樣復(fù)雜。
在本節(jié)中,我們將利用 CNN 的分層表示將一幅圖像的風(fēng)格自動應(yīng)用到另一幅圖像,即 風(fēng)格遷移 (Gatys等人,2016 年)。此任務(wù)需要兩張輸入圖像:一張是內(nèi)容圖像,另一張是風(fēng)格圖像。我們將使用神經(jīng)網(wǎng)絡(luò)修改內(nèi)容圖像,使其在風(fēng)格上接近風(fēng)格圖像。例如 圖14.12.1中的內(nèi)容圖片是我們在西雅圖郊區(qū)雷尼爾山國家公園拍攝的風(fēng)景照,而風(fēng)格圖是一幅以秋天的橡樹為主題的油畫。在輸出的合成圖像中,應(yīng)用了樣式圖像的油畫筆觸,使顏色更加鮮艷,同時保留了內(nèi)容圖像中對象的主要形狀。
圖 14.12.1給定內(nèi)容和風(fēng)格圖像,風(fēng)格遷移輸出合成圖像。
14.12.1。方法
圖 14.12.2用一個簡化的例子說明了基于 CNN 的風(fēng)格遷移方法。首先,我們將合成圖像初始化為內(nèi)容圖像。這張合成圖像是風(fēng)格遷移過程中唯一需要更新的變量,即訓(xùn)練期間要更新的模型參數(shù)。然后我們選擇一個預(yù)訓(xùn)練的 CNN 來提取圖像特征并在訓(xùn)練期間凍結(jié)其模型參數(shù)。這種深度 CNN 使用多層來提取圖像的層次特征。我們可以選擇其中一些層的輸出作為內(nèi)容特征或樣式特征。如圖14.12.2舉個例子。這里的預(yù)訓(xùn)練神經(jīng)網(wǎng)絡(luò)有 3 個卷積層,其中第二層輸出內(nèi)容特征,第一層和第三層輸出風(fēng)格特征。
圖 14.12.2基于 CNN 的風(fēng)格遷移過程。實線表示正向傳播方向,虛線表示反向傳播。
接下來,我們通過正向傳播(實線箭頭方向)計算風(fēng)格遷移的損失函數(shù),并通過反向傳播(虛線箭頭方向)更新模型參數(shù)(輸出的合成圖像)。風(fēng)格遷移中常用的損失函數(shù)由三部分組成:(i)內(nèi)容損失使合成圖像和內(nèi)容圖像在內(nèi)容特征上接近;(ii)風(fēng)格損失使得合成圖像和風(fēng)格圖像在風(fēng)格特征上接近;(iii) 總變差損失有助于減少合成圖像中的噪聲。最后,當(dāng)模型訓(xùn)練結(jié)束后,我們輸出風(fēng)格遷移的模型參數(shù),生成最終的合成圖像。
下面,我們將通過一個具體的實驗來解釋風(fēng)格遷移的技術(shù)細(xì)節(jié)。
14.12.2。閱讀內(nèi)容和樣式圖像
首先,我們閱讀內(nèi)容和樣式圖像。從它們打印的坐標(biāo)軸,我們可以看出這些圖像具有不同的尺寸。
%matplotlib inline import torch import torchvision from torch import nn from d2l import torch as d2l d2l.set_figsize() content_img = d2l.Image.open('../img/rainier.jpg') d2l.plt.imshow(content_img);
style_img = d2l.Image.open('../img/autumn-oak.jpg') d2l.plt.imshow(style_img);
%matplotlib inline from mxnet import autograd, gluon, image, init, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() d2l.set_figsize() content_img = image.imread('../img/rainier.jpg') d2l.plt.imshow(content_img.asnumpy());
style_img = image.imread('../img/autumn-oak.jpg') d2l.plt.imshow(style_img.asnumpy());
14.12.3。預(yù)處理和后處理
下面,我們定義了兩個用于預(yù)處理和后處理圖像的函數(shù)。該preprocess函數(shù)對輸入圖像的三個 RGB 通道中的每一個進行標(biāo)準(zhǔn)化,并將結(jié)果轉(zhuǎn)換為 CNN 輸入格式。該postprocess函數(shù)將輸出圖像中的像素值恢復(fù)為標(biāo)準(zhǔn)化前的原始值。由于圖像打印功能要求每個像素都有一個從0到1的浮點值,我們將任何小于0或大于1的值分別替換為0或1。
rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) def preprocess(img, image_shape): transforms = torchvision.transforms.Compose([ torchvision.transforms.Resize(image_shape), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)]) return transforms(img).unsqueeze(0) def postprocess(img): img = img[0].to(rgb_std.device) img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
rgb_mean = np.array([0.485, 0.456, 0.406]) rgb_std = np.array([0.229, 0.224, 0.225]) def preprocess(img, image_shape): img = image.imresize(img, *image_shape) img = (img.astype('float32') / 255 - rgb_mean) / rgb_std return np.expand_dims(img.transpose(2, 0, 1), axis=0) def postprocess(img): img = img[0].as_in_ctx(rgb_std.ctx) return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)
14.12.4。提取特征
我們使用在 ImageNet 數(shù)據(jù)集上預(yù)訓(xùn)練的 VGG-19 模型來提取圖像特征( Gatys et al. , 2016 )。
pretrained_net = torchvision.models.vgg19(pretrained=True)
pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)
為了提取圖像的內(nèi)容特征和風(fēng)格特征,我們可以選擇VGG網(wǎng)絡(luò)中某些層的輸出。一般來說,越靠近輸入層越容易提取圖像的細(xì)節(jié),反之越容易提取圖像的全局信息。為了避免在合成圖像中過度保留內(nèi)容圖像的細(xì)節(jié),我們選擇了一個更接近輸出的VGG層作為內(nèi)容層來輸出圖像的內(nèi)容特征。我們還選擇不同 VGG 層的輸出來提取局部和全局風(fēng)格特征。這些圖層也稱為樣式圖層。如第 8.2 節(jié)所述,VGG 網(wǎng)絡(luò)使用 5 個卷積塊。在實驗中,我們選擇第四個卷積塊的最后一個卷積層作為內(nèi)容層,每個卷積塊的第一個卷積層作為樣式層。這些層的索引可以通過打印pretrained_net實例來獲得。
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
當(dāng)使用 VGG 層提取特征時,我們只需要使用從輸入層到最接近輸出層的內(nèi)容層或樣式層的所有那些。讓我們構(gòu)建一個新的網(wǎng)絡(luò)實例net,它只保留所有用于特征提取的 VGG 層。
net = nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers + style_layers) + 1)])
net = nn.Sequential() for i in range(max(content_layers + style_layers) + 1): net.add(pretrained_net.features[i])
給定輸入X,如果我們簡單地調(diào)用前向傳播 net(X),我們只能得到最后一層的輸出。由于我們還需要中間層的輸出,因此我們需要逐層計算并保留內(nèi)容層和樣式層的輸出。
def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles
def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles
下面定義了兩個函數(shù):get_contents函數(shù)從內(nèi)容圖像中提取內(nèi)容特征,函數(shù)get_styles從風(fēng)格圖像中提取風(fēng)格特征。由于在訓(xùn)練期間不需要更新預(yù)訓(xùn)練 VGG 的模型參數(shù),我們甚至可以在訓(xùn)練開始之前提取內(nèi)容和風(fēng)格特征。由于合成圖像是一組需要更新的模型參數(shù)以進行風(fēng)格遷移,因此我們只能extract_features 在訓(xùn)練時通過調(diào)用函數(shù)來提取合成圖像的內(nèi)容和風(fēng)格特征。
def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).to(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).to(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y
def get_contents(image_shape, device): content_X = preprocess(content_img, image_shape).copyto(device) contents_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, contents_Y def get_styles(image_shape, device): style_X = preprocess(style_img, image_shape).copyto(device) _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y
14.12.5。定義損失函數(shù)
現(xiàn)在我們將描述風(fēng)格遷移的損失函數(shù)。損失函數(shù)包括內(nèi)容損失、風(fēng)格損失和全變損失。
14.12.5.1。內(nèi)容丟失
類似于線性回歸中的損失函數(shù),內(nèi)容損失通過平方損失函數(shù)衡量合成圖像和內(nèi)容圖像之間內(nèi)容特征的差異。平方損失函數(shù)的兩個輸入都是該extract_features函數(shù)計算的內(nèi)容層的輸出。
def content_loss(Y_hat, Y): # We detach the target content from the tree used to dynamically compute # the gradient: this is a stated value, not a variable. Otherwise the loss # will throw an error. return torch.square(Y_hat - Y.detach()).mean()
def content_loss(Y_hat, Y): return np.square(Y_hat - Y).mean()
14.12.5.2。風(fēng)格損失
風(fēng)格損失與內(nèi)容損失類似,也是使用平方損失函數(shù)來衡量合成圖像與風(fēng)格圖像之間的風(fēng)格差異。為了表達(dá)任何樣式層的樣式輸出,我們首先使用函數(shù)extract_features來計算樣式層輸出。假設(shè)輸出有 1 個示例,c渠道,高度 h, 和寬度w,我們可以將此輸出轉(zhuǎn)換為矩陣 X和c行和hw列。這個矩陣可以被認(rèn)為是串聯(lián)c載體 x1,…,xc, 其中每一個的長度為hw. 在這里,矢量xi表示頻道的風(fēng)格特征i.
在這些向量的 Gram 矩陣中XX?∈Rc×c, 元素 xij在排隊i和專欄j是向量的點積xi和xj. 表示渠道風(fēng)格特征的相關(guān)性i和 j. 我們使用這個 Gram 矩陣來表示任何樣式層的樣式輸出。請注意,當(dāng)值hw越大,它可能會導(dǎo)致 Gram 矩陣中的值越大。還要注意,Gram矩陣的高和寬都是通道數(shù)c. 為了讓風(fēng)格損失不受這些值的影響,gram 下面的函數(shù)將 Gram 矩陣除以其元素的數(shù)量,即chw.
def gram(X): num_channels, n = X.shape[1], X.numel() // X.shape[1] X = X.reshape((num_channels, n)) return torch.matmul(X, X.T) / (num_channels * n)
def gram(X): num_channels, n = X.shape[1], d2l.size(X) // X.shape[1] X = X.reshape((num_channels, n)) return np.dot(X, X.T) / (num_channels * n)
顯然,風(fēng)格損失的平方損失函數(shù)的兩個格拉姆矩陣輸入是基于合成圖像和風(fēng)格圖像的風(fēng)格層輸出。這里假設(shè) gram_Y基于風(fēng)格圖像的 Gram 矩陣已經(jīng)預(yù)先計算好了。
def style_loss(Y_hat, gram_Y): return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
def style_loss(Y_hat, gram_Y): return np.square(gram(Y_hat) - gram_Y).mean()
14.12.5.3。總變異損失
有時,學(xué)習(xí)到的合成圖像有很多高頻噪聲,即特別亮或特別暗的像素。一種常見的降噪方法是全變差去噪。表示為 xi,j坐標(biāo)處的像素值(i,j). 減少總變異損失
(14.12.1)∑i,j|xi,j?xi+1,j|+|xi,j?xi,j+1|
使合成圖像上相鄰像素的值更接近。
def tv_loss(Y_hat): return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
def tv_loss(Y_hat): return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
14.12.5.4。損失函數(shù)
風(fēng)格遷移的損失函數(shù)是內(nèi)容損失、風(fēng)格損失和總變異損失的加權(quán)和。通過調(diào)整這些權(quán)重超參數(shù),我們可以在合成圖像的內(nèi)容保留、風(fēng)格遷移和降噪之間取得平衡。
content_weight, style_weight, tv_weight = 1, 1e4, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l
content_weight, style_weight, tv_weight = 1, 1e4, 10 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): # Calculate the content, style, and total variance losses respectively contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] tv_l = tv_loss(X) * tv_weight # Add up all the losses l = sum(styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l
14.12.6. 初始化合成圖像
在風(fēng)格遷移中,合成圖像是訓(xùn)練期間唯一需要更新的變量。因此,我們可以定義一個簡單的模型, SynthesizedImage并將合成圖像作為模型參數(shù)。在這個模型中,前向傳播只返回模型參數(shù)。
class SynthesizedImage(nn.Module): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = nn.Parameter(torch.rand(*img_shape)) def forward(self): return self.weight
class SynthesizedImage(nn.Block): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = self.params.get('weight', shape=img_shape) def forward(self): return self.weight.data()
接下來,我們定義get_inits函數(shù)。此函數(shù)創(chuàng)建一個合成圖像模型實例并將其初始化為 image X。styles_Y_gram在訓(xùn)練之前計算各種樣式層的樣式圖像的 Gram 矩陣 。
def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape).to(device) gen_img.weight.data.copy_(X.data) trainer = torch.optim.Adam(gen_img.parameters(), lr=lr) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer
def get_inits(X, device, lr, styles_Y): gen_img = SynthesizedImage(X.shape) gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True) trainer = gluon.Trainer(gen_img.collect_params(), 'adam', {'learning_rate': lr}) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer
14.12.7. 訓(xùn)練
在訓(xùn)練風(fēng)格遷移模型時,我們不斷提取合成圖像的內(nèi)容特征和風(fēng)格特征,并計算損失函數(shù)。下面定義了訓(xùn)練循環(huán)。
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): trainer.zero_grad() contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step() scheduler.step() if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X)) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], ylim=[0, 20], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) for epoch in range(num_epochs): with autograd.record(): contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) l.backward() trainer.step(1) if (epoch + 1) % lr_decay_epoch == 0: trainer.set_learning_rate(trainer.learning_rate * 0.8) if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X).asnumpy()) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X
現(xiàn)在我們開始訓(xùn)練模型。我們將內(nèi)容和樣式圖像的高度和寬度重新調(diào)整為 300 x 450 像素。我們使用內(nèi)容圖像來初始化合成圖像。
device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w) net = net.to(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
device, image_shape = d2l.try_gpu(), (450, 300) net.collect_params().reset_ctx(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.9, 500, 50)
我們可以看到,合成圖保留了內(nèi)容圖的景物和物體,同時傳遞了風(fēng)格圖的顏色。例如,合成圖像具有風(fēng)格圖像中的顏色塊。其中一些塊甚至具有筆觸的微妙紋理。
14.12.8。概括
風(fēng)格遷移中常用的損失函數(shù)由三部分組成:(i)內(nèi)容損失使合成圖像和內(nèi)容圖像在內(nèi)容特征上接近;(ii) 風(fēng)格損失使得合成圖像和風(fēng)格圖像在風(fēng)格特征上接近;(iii) 總變差損失有助于減少合成圖像中的噪聲。
我們可以使用預(yù)訓(xùn)練的 CNN 提取圖像特征并最小化損失函數(shù),以在訓(xùn)練期間不斷更新合成圖像作為模型參數(shù)。
我們使用 Gram 矩陣來表示樣式層的樣式輸出。
14.12.9。練習(xí)
當(dāng)您選擇不同的內(nèi)容和樣式層時,輸出如何變化?
調(diào)整損失函數(shù)中的權(quán)重超參數(shù)。輸出是保留更多內(nèi)容還是噪音更少?
使用不同的內(nèi)容和樣式圖像。你能創(chuàng)造出更有趣的合成圖像嗎?
我們可以對文本應(yīng)用樣式轉(zhuǎn)換嗎?提示:您可以參考Hu等人的調(diào)查論文。( 2022 )。
-
cnn
+關(guān)注
關(guān)注
3文章
353瀏覽量
22248 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13249
發(fā)布評論請先 登錄
相關(guān)推薦
評論