Pytorchでカスタムデータセット作成してみる(CSVファイル、画像読み込み)
今回はpytorchのcifar10画像分類のチュートリアル実施からちょっと進んで、カスタムデータセットをpytorchで読み込めるように挑戦してみます。前回のpytorch cifar10画像分類チュートリアル実施してみた記事はこちらです。
チュートリアルでは以下のようにほぼ何もしなくてもcifar10データがよみこまれます。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
実際の現場ではこんな風にラクにはならずに、用途に応じでデータセットを作成する必要があるかと思います。
ここではよく想定されそうな、CSVファイルに画像ファイルと正解ラベルがあるようなファイルを想定してpytorchへの読み込み(dataset化,dataloader化)を試してみます。
csvファイル作成
画像はcifar10画像を借用してみます。cifar10からの個別の画像取り出しはググって頂くとして、こんな感じのCSVファイルを作成してみました。airplaneカテゴリの画像にもいくつか種類があるようですので、さらに細かく分類できるか試してみる想定です。airbus_sは「0」,airliner_sは「1」,attack_aircraft_sは「2」としときます。
##airplane.csv
path,label
./cif10/airplane/airbus_s_000012.png,0
./cif10/airplane/airbus_s_000013.png,0
./cif10/airplane/airbus_s_000024.png,0
./cif10/airplane/airbus_s_000107.png,0
./cif10/airplane/airbus_s_000119.png,0
./cif10/airplane/airliner_s_000013.png,1
./cif10/airplane/airliner_s_000020.png,1
./cif10/airplane/airliner_s_000050.png,1
./cif10/airplane/airliner_s_000051.png,1
./cif10/airplane/airliner_s_000088.png,1
./cif10/airplane/attack_aircraft_s_000003.png,2
./cif10/airplane/attack_aircraft_s_000005.png,2
./cif10/airplane/attack_aircraft_s_000011.png,2
./cif10/airplane/attack_aircraft_s_000029.png,2
./cif10/airplane/attack_aircraft_s_000037.png,2
・・・・以下データが続く
カスタムデータセットクラス作成
上で作成したairplaneと同じフォルダにカスタムデータセットクラスmy_dataset.pyを作ってみます。作り方はDatasetクラスを継承したクラスを作って、「__init__」,「__getitem__」,「__len__」 などの関数を実装すればOKですね。
「__getitem__」 のところで、画像と正解ラベル以外にファイル名も返すようにしてます。
またtransforms.Resize((32,32))で例えばサイズが違う画像が来てもこのサイズにして落とし込めます。cifar10の画像はすべて32×32にはなっていますが。
transforms.ToTensor()は必須ですね。ないとダメです。動きません。
#my_dataset.py
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision import transforms
class my_dataset(Dataset):
def __init__(self, csv_path,transform=None):
#csvファイル読み込み。
df = pd.read_csv(csv_path)
image_paths = df['path']
labels = df['label']
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __getitem__(self, index):
path = self.image_paths[index]
#画像読み込み。
img = Image.open(path)
#transform事前処理実施
if self.transform is not None:
img = self.transform(img)
label=self.labels[index]
image_path=self.image_paths[index]
return img,label,image_path
def __len__(self):
#データ数を返す
return len(self.image_paths)
if __name__ == '__main__':
#transformで32x32画素に変換して、テンソル化。
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
#データセット作成
dataset = my_dataset("./airplane.csv",transform)
#dataloader化
dataloader = DataLoader(dataset, batch_size=4)
#データローダの中身確認
for img,label ,image_path in dataloader:
print('label=',label)
print('image_path=',image_path)
print('img.shape=',img.shape)
transformsはいろいろ便利すぎていまいち使いこなせないですね。。別途実験したいと思います。
データローダの出力確認
上記のコードを実行してdataloaderの中身を確認してみます。
$python my_dataset.py
とあるループでの実行結果はこんなですね。
label= tensor([0, 1, 1, 1])
image_path= ('./cif10/airplane/airbus_s_000119.png', './cif10/airplane/airliner_s_000013.png', './cif10/airplane/airliner_s_000020.png', './cif10/airplane/airliner_s_000050.png')
img.shape= torch.Size([4, 3, 32, 32])
batch_size=4を指定してますので、データローダのデータ一つは4つデータが結合したデータになってるのがわかりますね。tensor([0, 1, 1, 1])は正解ラベルの配列ですね。img.shape= torch.Size([4, 3, 32, 32])も32×32のカラー画像が4つあるよ、ってことですかね。
まとめ
このデータセットを使って学習を実施してみました。が残念ながら学習が収束しなかったので、airplaneからさらにそのサブカテゴリまで判別するのは簡単ではない感じですかね。参考までにモデル定義、学習実施のファイルも書いときます。
#model.py
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
#train.py
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import model
import my_dataset
def train():
net = model.Net()
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])
trainset = my_dataset.my_dataset("./airplane.csv",transform)
trainloader = DataLoader(trainset, batch_size=2)
#損失関数(criterion)、最適化関数を定義(optimizer)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(10): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels ,image_path = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 10 == 9: # print every 10 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 10))
running_loss = 0.0
print('Finished Training')
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
if __name__ == '__main__':
train()