CIFAR10でAutoencoder(CNNで)試してみる(Pytorchで)

ちょっと前にCIFAR10でカスタムデータセットのクラスを作ってみましたが、今回はそれを使ってCIFAR10のオートエンコーダの実験をしてみました。MNISTではどうやってもうまくいきそうな気がしますので、CIFAR10で実験です。

フォルダ構成はこんなと想定してください。cifar10_imgフォルダ以下のカテゴリ名のフォルダ以下に、PNGファイルがたくさん保存されています。

今回はairplaneについてのCNNを利用したAutoencoderができるか試してみます。

コード

まずはモデルです。ちなみにいくつかモデル試してみましたが、どのモデルも結果としてはイマイチであまり違いが出ませんでしたね。。ちなみにこれはSony NNCのサンプルプロジェクトにインスパイアされたモデルです。

#model.py

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, 5, stride=2, padding=2),            
            nn.PReLU(),
            nn.Conv2d(8, 8, 5, stride=2, padding=2),          
            nn.PReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 8, 6, stride=2, padding=2),  
            nn.PReLU(),
	  nn.ConvTranspose2d(8, 3, 6, stride=2, padding=2),  
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

データセットの取り方はこんなですね。airplaneフォルダの画像を読み込むだけですね。ラベルはAutoencoderなのでいらないですが、とりあえず0にしときます。

#my_dataset.py
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision import transforms
import glob
import os

class my_dataset(Dataset):
    def __init__(self, img_path,transform=None):
        
        image_paths = glob.glob(img_path + '/*.png')
        labels = os.path.basename(img_path)

        self.image_paths = image_paths
        self.labels = 0 
        self.transform = transform

    def __getitem__(self, index):
        path = self.image_paths[index] 
        #画像読み込み。
        img = Image.open(path)
        #transform事前処理実施
        if self.transform is not None:
            img = self.transform(img)

        label=self.labels
        image_path=self.image_paths[index]
        return img,label

    def __len__(self):
        return len(self.image_paths)

if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
    #データセット作成
    dataset = my_dataset("../cifar10_img/airplane",transform)
    #dataloader化
    dataloader = DataLoader(dataset, batch_size=4)

    #データローダの中身確認
    for img,label in dataloader:
        print('label=',label)
        print('image_path=',img.shape)

訓練コードですね。Autoencoder時の損失関数は 「nn.BCELoss()」ですね。

#train.py
import torch
import torchvision
import torchvision.transforms as transforms

import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import model
import my_dataset

def train():
    net = model.Net()
    transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])

    trainset = my_dataset.my_dataset("../cifar10_img/airplane",transform)
    trainloader = DataLoader(trainset, batch_size=4)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(10):
        running_loss = 0.0
        for i, (inputs, _) in enumerate(trainloader, 0):
            outputs = net(inputs)
            loss = criterion(outputs, inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.data
            if i % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
    print('Finished Training')
    print('Saving Model...')
  
    torch.save(net.state_dict(), ".autoencoder.pth")
  
if __name__ == '__main__':
    train()

テストコードですね。テストといいつつ訓練のデータを使ってます。

#test.py

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

import model
import my_dataset
import matplotlib.pyplot as plt
import numpy as np


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def test():
    PATH = './.autoencoder.pth'
    net = model.Net()
    #学習済みデータ読み込み
    net.load_state_dict(torch.load(PATH))
    transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])

    trainset = my_dataset.my_dataset("../cifar10_img/airplane",transform)
    trainloader = DataLoader(trainset, batch_size=4)

    classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    # print images
    imshow(torchvision.utils.make_grid(images))
    outputs = net(images)
    imshow(torchvision.utils.make_grid(outputs.data))

if __name__ == '__main__':
    test()

結果

こちらサンプルの入力画像です。

airplane_ae前

モデルにかけてやった出力です。イマイチ感伝わりますでしょうか。。

airplane_ae後

ちなみにshipをテストしてみた結果ですね。airplane風にはなりませんでしたね。

ship入力

ship_ae前

ship出力

ship_ae後

ということでcifar10(32x32)レベルでもカラー画像のAutoencoderは難易度が高い、ということでしょうか。。