媽媽小時候已經有彩色照片了,不過那些照片,還是照相館的人類手動上色的。
幾十年之後,人們已經開始培育深度神經網絡,來給老照片和老電影上色了。
來自哈佛大學的Luke Melas-Kyriazi (我叫他盧克吧) ,用自己訓練的神經網絡,把卓別林變成了彩色的卓別林,清新自然。
作為一隻哈佛學霸,盧克還為鑽研機器學習的小夥伴們寫了一個基於PyTorch的教程。
雖然教程裡的模型比給卓別林用的模型要簡約一些,但效果也是不錯了。
問題是什麼?盧克說,給黑白照片上色這個問題的難點在於,它是多模態的——與一幅灰度圖像對應的合理彩色圖像,並不唯一。
△ 這並不是正確示範傳統模型需要輸入許多額外信息,來輔助上色。
而深度神經網絡,除了灰度圖像之外,不需要任何額外輸入,就可以完成上色。
在彩色圖像裡,每個像素包含三個值,即亮度、飽和度以及色調。
而灰度圖像,並無飽和度和色調可言,只有亮度一個值。
所以,模型要用一組數據,生成另外兩足數據。換句話說,以灰度圖像為起點,推斷出對應的彩色圖像。
為了簡單,這裡只做了256 x 256像素的圖像上色。輸出的數據量則是256 x 256 x 2。
關於顏色表示,盧克用的是LAB色彩空間,它跟RGB系統包含的信息是一樣的。
但對程序猿來說,前者比較方便把亮度和其他兩項分離開來。
數據也不難獲得,盧克用了MIT Places數據集,中的一部分。內容就是校園裡的一些地標和風景。然後轉換成黑白圖像,就可以了。以下為數據搬運代碼——
1
2!wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz
3!tar -xzf testSetPlaces205_resize.tar.gz
1
2import os
3os.makedirs('images/train/class/', exist_ok=True)
4os.makedirs('images/val/class/', exist_ok=True)
5for i, file in enumerate(os.listdir('testSet_resize')):
6 if i < 1000:
7 os.rename('testSet_resize/' + file, 'images/val/class/' + file)
8 else:
9 os.rename('testSet_resize/' + file, 'images/train/class/' + file)
1
2from IPython.display import Image, display
3display(Image(filename='images/val/class/84b3ccd8209a4db1835988d28adfed4c.jpg'))
搭建模型和訓練模型是在PyTorch裡完成的。
還用了torchvishion,這是一套在PyTorch上處理圖像和視頻的工具。
另外,scikit-learn能完成圖片在RGB和LAB色彩空間之間的轉換。
1
2!pip install torch torchvision matplotlib numpy scikit-image pillow==4.1.1
1
2import numpy as np
3import matplotlib.pyplot as plt
4%matplotlib inline
5
6from skimage.color import lab2rgb, rgb2lab, rgb2gray
7from skimage import io
8
9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12
13import torchvision.models as models
14from torchvision import datasets, transforms
15
16import os, shutil, time
1
2use_gpu = torch.cuda.is_available()
神經網絡裡面,第一部分是幾層用來提取圖像特徵;第二部分是一些反卷積層 (Deconvolutional Layers) ,用來給那些特徵增加解析度。
具體來說,第一部分用的是ResNet-18,這是一個圖像分類網絡,有18層,以及一些殘差連接 (Residual Connections) 。
給第一層做些修改,它就可以接受灰度圖像輸入了。然後把第6層之後的都去掉。
然後,用代碼來定義一下這個模型。
從神經網絡的第二部分 (就是那些上採樣層) 開始。
1class ColorizationNet(nn.Module):
2 def __init__(self, input_size=128):
3 super(ColorizationNet, self).__init__()
4 MIDLEVEL_FEATURE_SIZE = 128
5
6
7 resnet = models.resnet18(num_classes=365)
8
9 resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
10
11 self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
12
13
14 self.upsample = nn.Sequential(
15 nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
16 nn.BatchNorm2d(128),
17 nn.ReLU(),
18 nn.Upsample(scale_factor=2),
19 nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
20 nn.BatchNorm2d(64),
21 nn.ReLU(),
22 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
23 nn.BatchNorm2d(64),
24 nn.ReLU(),
25 nn.Upsample(scale_factor=2),
26 nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
27 nn.BatchNorm2d(32),
28 nn.ReLU(),
29 nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
30 nn.Upsample(scale_factor=2)
31 )
32
33 def forward(self, input):
34
35
36 midlevel_features = self.midlevel_resnet(input)
37
38
39 output = self.upsample(midlevel_features)
40 return output
下一步,創建模型吧。
1model = ColorizationNet()
它是怎麼訓練的?預測每個像素的色值,用的是回歸 (Regression) 的方法。
損失函數 (Loss Function)所以,用了一個均方誤差 (MSE) 損失函數——讓預測的色值與參考標準 (Ground Truth) 之間的距離平方最小化。
1criterion = nn.MSELoss()
優化損失函數
這裡是用Adam Optimizer優化的。
1optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)
加載數據
用torchtext加載數據。首先定義一個專屬的數據加載器 (DataLoader) ,來完成RGB到LAB空間的轉換。
1class GrayscaleImageFolder(datasets.ImageFolder):
2 '''Custom images folder, which converts images to grayscale before loading'''
3 def __getitem__(self, index):
4 path, target = self.imgs[index]
5 img = self.loader(path)
6 if self.transform is not None:
7 img_original = self.transform(img)
8 img_original = np.asarray(img_original)
9 img_lab = rgb2lab(img_original)
10 img_lab = (img_lab + 128) / 255
11 img_ab = img_lab[:, :, 1:3]
12 img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
13 img_original = rgb2gray(img_original)
14 img_original = torch.from_numpy(img_original).unsqueeze(0).float()
15 if self.target_transform is not None:
16 target = self.target_transform(target)
17 return img_original, img_ab, target
再來,就是定義訓練數據和驗證數據的轉換。
1
2train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
3train_imagefolder = GrayscaleImageFolder('images/train', train_transforms)
4train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=64, shuffle=True)
5
6
7val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
8val_imagefolder = GrayscaleImageFolder('images/val' , val_transforms)
9val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=64, shuffle=False)
訓練開始之前,要把輔助函數寫好,來追蹤訓練損失,並把圖像轉回RGB形式。
1class AverageMeter(object):
2 '''A handy class from the PyTorch ImageNet tutorial'''
3 def __init__(self):
4 self.reset()
5 def reset(self):
6 self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
7 def update(self, val, n=1):
8 self.val = val
9 self.sum += val * n
10 self.count += n
11 self.avg = self.sum / self.count
12
13def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
14 '''Show/save rgb image from grayscale and ab channels
15 Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
16 plt.clf()
17 color_image = torch.cat((grayscale_input, ab_input), 0).numpy()
18 color_image = color_image.transpose((1, 2, 0))
19 color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
20 color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
21 color_image = lab2rgb(color_image.astype(np.float64))
22 grayscale_input = grayscale_input.squeeze().numpy()
23 if save_path is not None and save_name is not None:
24 plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
25 plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))
不用反向傳播 (Back Propagation),直接用torch.no_grad() 跑模型。
1def validate(val_loader, model, criterion, save_images, epoch):
2 model.eval()
3
4
5 batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
6
7 end = time.time()
8 already_saved_images = False
9 for i, (input_gray, input_ab, target) in enumerate(val_loader):
10 data_time.update(time.time() - end)
11
12
13 if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()
14
15
16 output_ab = model(input_gray)
17 loss = criterion(output_ab, input_ab)
18 losses.update(loss.item(), input_gray.size(0))
19
20
21 if save_images and not already_saved_images:
22 already_saved_images = True
23 for j in range(min(len(output_ab), 10)):
24 save_path = {'grayscale': 'outputs/gray/', 'colorized': 'outputs/color/'}
25 save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)
26 to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)
27
28
29 batch_time.update(time.time() - end)
30 end = time.time()
31
32
33 if i % 25 == 0:
34 print('Validate: [{0}/{1}]\t'
35 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
36 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
37 i, len(val_loader), batch_time=batch_time, loss=losses))
38
39 print('Finished validation.')
40 return losses.avg
訓練
用loss.backward(),用上反向傳播。寫一下訓練數據跑一遍 (one epoch) 用的函數。
1def train(train_loader, model, criterion, optimizer, epoch):
2 print('Starting training epoch {}'.format(epoch))
3 model.train()
4
5
6 batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
7
8 end = time.time()
9 for i, (input_gray, input_ab, target) in enumerate(train_loader):
10
11
12 if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()
13
14
15 data_time.update(time.time() - end)
16
17
18 output_ab = model(input_gray)
19 loss = criterion(output_ab, input_ab)
20 losses.update(loss.item(), input_gray.size(0))
21
22
23 optimizer.zero_grad()
24 loss.backward()
25 optimizer.step()
26
27
28 batch_time.update(time.time() - end)
29 end = time.time()
30
31
32 if i % 25 == 0:
33 print('Epoch: [{0}][{1}/{2}]\t'
34 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
35 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
36 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
37 epoch, i, len(train_loader), batch_time=batch_time,
38 data_time=data_time, loss=losses))
39
40 print('Finished training epoch {}'.format(epoch))
然後,定義一個訓練迴路 (Training Loop) ,跑一百遍訓練數據。從Epoch 0開始訓練。
1
2if use_gpu:
3 criterion = criterion.cuda()
4 model = model.cuda()
1
2os.makedirs('outputs/color', exist_ok=True)
3os.makedirs('outputs/gray', exist_ok=True)
4os.makedirs('checkpoints', exist_ok=True)
5save_images = True
6best_losses = 1e10
7epochs = 100
1
2for epoch in range(epochs):
3
4 train(train_loader, model, criterion, optimizer, epoch)
5 with torch.no_grad():
6 losses = validate(val_loader, model, criterion, save_images, epoch)
7
8 if losses < best_losses:
9 best_losses = losses
10 torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))
是時候看看修煉成果了,所以,複製一下這段代碼。
1
2import matplotlib.image as mpimg
3image_pairs = [('outputs/color/img-2-epoch-0.jpg', 'outputs/gray/img-2-epoch-0.jpg'),
4 ('outputs/color/img-7-epoch-0.jpg', 'outputs/gray/img-7-epoch-0.jpg')]
5for c, g in image_pairs:
6 color = mpimg.imread(c)
7 gray = mpimg.imread(g)
8 f, axarr = plt.subplots(1, 2)
9 f.set_size_inches(15, 15)
10 axarr[0].imshow(gray, cmap='gray')
11 axarr[1].imshow(color)
12 axarr[0].axis('off'), axarr[1].axis('off')
13 plt.show()
效果還是很自然的,雖然生成的彩色圖像不是那麼明麗。
盧克說,問題是多模態的,所以損失函數還是值得推敲。
比如,一條灰色裙子可以是藍色也可以是紅色。如果模型選擇的顏色和參考標準不同,就會受到嚴厲的懲罰。
這樣一來,模型就會選擇哪些不會被判為大錯特錯的顏色,而不太選擇非常顯眼明亮的顏色。
沒時間怎麼辦?盧克還把一隻訓練好的AI放了出來,不想從零開始訓練的小夥伴們,也可以直接感受他的訓練成果,只要用以下代碼下載就好了。
1
2!wget https://www.dropbox.com/s/kz76e7gv2ivmu8p/model-epoch-93.pth
3
1
2pretrained = torch.load('model-epoch-93.pth', map_location=lambda storage, loc: storage)
3model.load_state_dict(pretrained)
1
2save_images = True
3with torch.no_grad():
4 validate(val_loader, model, criterion, save_images, 0)
如果想要更加有聲有色的結局,就不能繼續偷懶了。盧克希望大家沿著他精心鋪就的路,走到更遠的地方。
要替換當前的損失函數,可以參考Zhang et al. (2017):
https://richzhang.github.io/ideepcolor/
無監督學習的上色大法,可以參考Larsson et al. (2017):
http://people.cs.uchicago.edu/~larsson/color-proxy/
另外,可以做個手機應用,就像谷歌在I/O大會上發布的著色軟體那樣。
黑白電影,也可以自己去嘗試,一幀一幀地上色。
這裡有卓別林用到的完整代碼:
https://github.com/lukemelas/Automatic-Image-Colorization/
量子位正在招募編輯/記者,工作地點在北京中關村。期待有才氣、有熱情的同學加入我們!相關細節,請在量子位公眾號(QbitAI)對話界面,回復「招聘」兩個字。