VGG16でアリとハチのクラス分類をファインチューニングして、その中間層の中身をcnn-visualizationsを使って可視化してみました。
ファインチューニング
まずはPyTorchでVGG16でファインチューニングをしてみます。対象のデータはKaggleにあるアリとハチの2クラス分類問題でやってみます。このデータはtrainとvalにデータがフォルダで分かれていてその中にあるアリ(ants)とハチ(bees)のフォルダに画像がたくさんあるという形になっています。
PyTorchに用意されている事前学習済みのVGG16をベースに学習を実施します。
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms, models import torchvision from PIL import Image # 学習済みモデルの取得 model = models.vgg16(pretrained=True) # 全結合層の変更(最終層の出力を2にする) model.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 2), ) # GPUを使う device = torch.device('cuda') model = model.to(device) # バッチサイズ batchsize = 8 # 学習データ traindata = torchvision.datasets.ImageFolder(root='./hymenoptera_data/train/', transform=transforms.Compose([ transforms.Resize((224,224), interpolation=Image.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) ) trainloader = torch.utils.data.DataLoader(dataset=traindata, batch_size=batchsize, shuffle=True) # テストデータ testdata = torchvision.datasets.ImageFolder(root='./hymenoptera_data/val/', transform=transforms.Compose([ transforms.Resize((224,224), interpolation=Image.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) ) testloader = torch.utils.data.DataLoader(dataset=testdata, batch_size=batchsize, shuffle=True) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 50エポック学習 for epoch in range(50): running_loss = 0.0 correct_num = 0 total_num = 0 for i, (data, target) in enumerate(trainloader): inputs, labels = data.to(device), target.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) predicted = torch.max(outputs.data, 1)[1] correct_num_temp = (predicted==labels).sum() correct_num += correct_num_temp.item() total_num += batchsize loss.backward() optimizer.step() running_loss += loss.item() # 経過の出力 print('epoch:%d loss: %.3f acc: %.3f' % (epoch + 1, running_loss / 100, correct_num*100/total_num)) # モデルの保存 torch.save(model.state_dict(), 'sample.pt') # テストデータ精度出力 model.eval() correct_num = 0 total_num = 0 for i, (data, target) in enumerate(testloader): inputs, labels = data.to(device), target.to(device) outputs = model(inputs) predicted = torch.max(outputs.data, 1)[1] correct_num_temp = (predicted == labels).sum() correct_num += correct_num_temp.item() total_num += batchsize print('test acc: %.3f' % (correct_num * 100 / total_num))
データローダでNormalizeしている箇所に0.485や0.229など見慣れないマジックナンバーがあると思いますが、これらはバッチノーマリゼーションの一種でこの値自体はImageNetの平均、標準偏差の値のようです。詳細は下記リンクを参考にしてください。
[PyTorch] 1. Transform, ImageFolder, DataLoader
【GIF】初心者のためのCNNからバッチノーマライゼーションとその仲間たちまでの解説
結果
epoch:1 loss: 0.165 acc: 72.177
epoch:2 loss: 0.118 acc: 82.661
epoch:3 loss: 0.191 acc: 71.774
epoch:4 loss: 0.088 acc: 89.113
epoch:5 loss: 0.018 acc: 97.177
epoch:6 loss: 0.002 acc: 98.387
~(省略)~
epoch:49 loss: 0.000 acc: 98.387
epoch:50 loss: 0.000 acc: 98.387
test acc: 88.125
かなり早い段階で収束してましたが、テストデータの精度は88%とそれほどよくないです。
今回は精度云々より可視化がメインなのでとりあえず目をつむって次に進みます。
cnn-visualizationsで可視化する
先ほどファインチューニングしたした結果(sample.pt)についてcnn-visualizationsを使って中間層の様子を可視化してみます。詳細は下記リンクを参照していただきたいですが、Grad-Camなど他の可視化手法も色々使えるようです。
元のソースを下記のように修正してファインチューニングしたモデルの中間層を可視化できるようにしてみました。途中sample.ptという自前のモデルを読み込んでいます。cnn_layerに可視化したいレイヤーを、filter_posに中間層のチャンネルを指定します。この場合512個チャンネルがあるので0-511まで選べます。forループですべて実行するのもありだと思います。
cnn_layer_visualization.py
~(省略)~ if __name__ == '__main__': cnn_layer = 17 filter_pos = 5 # 学習済みモデルを読み込み model = models.vgg16(pretrained=False) model.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 2) ) model.load_state_dict(torch.load('sample.pt')) pretrained_model = model.features layer_vis = CNNLayerVisualization(pretrained_model, cnn_layer, filter_pos) # Layer visualization with pytorch hooks layer_vis.visualise_layer_with_hooks() # Layer visualization without pytorch hooks # layer_vis.visualise_layer_without_hooks()
レイヤーについてはVGG16の層を出力した下記参考にしてください。
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
(※classifierの(6)のout_featuresは今回1000ではなく2)
結果0-5のチャンネルを可視化すると以下のようになりました。
恐らく左上や左下、右下あたりは目玉などの丸い特徴をとらえていて、真ん中の上は毛並みのような特徴をとらえているのだと思います。これを見ているだけでも結構面白いですが、これらが分類に対してどれくらい影響しているのか知るにはちょっと難しいと思いました。(512チャンネルもあるわけですしなおさらそのような気がします)。
また、実はファインチューニングしないでこの可視化をしてみましたが、出てきた画像はファインチューニングしたときの結果とあまり変わりませんでした。512チャンネルもあるので中身はほぼ変わる必要ないということなのだろうかと思います。
まとめ
アリとハチの2クラス問題をテーマにファインチューニング、CNN中間層の可視化を試してみました。