CIFAR10でAutoencoder(CNNで)試してみる(Pytorchで)
ちょっと前にCIFAR10でカスタムデータセットのクラスを作ってみましたが、今回はそれを使ってCIFAR10のオートエンコーダの実験をしてみました。MNISTではどうやってもうまくいきそうな気がしますので、CIFAR10で実験です。
フォルダ構成はこんなと想定してください。cifar10_imgフォルダ以下のカテゴリ名のフォルダ以下に、PNGファイルがたくさん保存されています。

今回はairplaneについてのCNNを利用したAutoencoderができるか試してみます。
コード
まずはモデルです。ちなみにいくつかモデル試してみましたが、どのモデルも結果としてはイマイチであまり違いが出ませんでしたね。。ちなみにこれはSony NNCのサンプルプロジェクトにインスパイアされたモデルです。
#model.py
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 8, 5, stride=2, padding=2),
nn.PReLU(),
nn.Conv2d(8, 8, 5, stride=2, padding=2),
nn.PReLU(),
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(8, 8, 6, stride=2, padding=2),
nn.PReLU(),
nn.ConvTranspose2d(8, 3, 6, stride=2, padding=2),
nn.Sigmoid(),
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
データセットの取り方はこんなですね。airplaneフォルダの画像を読み込むだけですね。ラベルはAutoencoderなのでいらないですが、とりあえず0にしときます。
#my_dataset.py
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision import transforms
import glob
import os
class my_dataset(Dataset):
def __init__(self, img_path,transform=None):
image_paths = glob.glob(img_path + '/*.png')
labels = os.path.basename(img_path)
self.image_paths = image_paths
self.labels = 0
self.transform = transform
def __getitem__(self, index):
path = self.image_paths[index]
#画像読み込み。
img = Image.open(path)
#transform事前処理実施
if self.transform is not None:
img = self.transform(img)
label=self.labels
image_path=self.image_paths[index]
return img,label
def __len__(self):
return len(self.image_paths)
if __name__ == '__main__':
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
#データセット作成
dataset = my_dataset("../cifar10_img/airplane",transform)
#dataloader化
dataloader = DataLoader(dataset, batch_size=4)
#データローダの中身確認
for img,label in dataloader:
print('label=',label)
print('image_path=',img.shape)
訓練コードですね。Autoencoder時の損失関数は 「nn.BCELoss()」ですね。
#train.py
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import model
import my_dataset
def train():
net = model.Net()
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
trainset = my_dataset.my_dataset("../cifar10_img/airplane",transform)
trainloader = DataLoader(trainset, batch_size=4)
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters())
for epoch in range(10):
running_loss = 0.0
for i, (inputs, _) in enumerate(trainloader, 0):
outputs = net(inputs)
loss = criterion(outputs, inputs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.data
if i % 200 == 199:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
print('Saving Model...')
torch.save(net.state_dict(), ".autoencoder.pth")
if __name__ == '__main__':
train()
テストコードですね。テストといいつつ訓練のデータを使ってます。
#test.py
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import model
import my_dataset
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def test():
PATH = './.autoencoder.pth'
net = model.Net()
#学習済みデータ読み込み
net.load_state_dict(torch.load(PATH))
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
trainset = my_dataset.my_dataset("../cifar10_img/airplane",transform)
trainloader = DataLoader(trainset, batch_size=4)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(trainloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
outputs = net(images)
imshow(torchvision.utils.make_grid(outputs.data))
if __name__ == '__main__':
test()
結果
こちらサンプルの入力画像です。

モデルにかけてやった出力です。イマイチ感伝わりますでしょうか。。

ちなみにshipをテストしてみた結果ですね。airplane風にはなりませんでしたね。
ship入力

ship出力

ということでcifar10(32x32)レベルでもカラー画像のAutoencoderは難易度が高い、ということでしょうか。。