Понимание и визуализация промежуточных слоев CNN

Мы все знаем, что сверточные нейронные сети (CNN) широко известны и используются для задач классификации изображений, причины их эффективности и конкретные шаблоны, изученные каждым уровнем CNN, остаются несколько загадочными. В этой статье я попытаюсь изучить визуализацию выходных данных каждого сверточного слоя, чтобы получить представление о наблюдаемых закономерностях и пролить свет на замечательную производительность CNN. После обучения модели мы визуализируем, как выглядят промежуточные активации. Эти шаги можно рассматривать как начальные этапы в области объяснимого ИИ и интерпретируемости моделей.

В этой конкретной задаче популярный набор данных MNIST и классическая модель «LeNet5» используются для классификации рукописных цифр MNIST.

  1. Вы можете скачать набор данных с официального сайта -(http://yann.lecun.com/exdb/mnist/) или
  2. напрямую использовать набор данных MNIST, доступный в PyTorch, используя (https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html).

Примечание. В этом задании используется исходная форма 28 x 28.

Сейчас я продемонстрирую код, объясняя его.

Импортируйте все необходимые библиотеки

# Import all the necessary libraries
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# check device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

Определить гиперпараметры

# parameters
RANDOM_SEED = 42
LEARNING_RATE = 0.001
BATCH_SIZE = 128
N_EPOCHS = 15

IMG_SIZE = 28
N_CLASSES = 10

# Unpool
unpool = nn.MaxUnpool2d(kernel_size=2)

Получение наборов данных для обучения и проверки с помощью загрузчиков данных

# Transforms
transforms_mnist = transforms.Compose([transforms.Resize((28, 28)),
                                 transforms.ToTensor()])

# Dowload and create Train and Validation Datasets
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms_mnist,
                               download=True)

validation_dataset = datasets.MNIST(root='data', 
                               train=False, 
                               transform=transforms_mnist)

# Create Data Loaders
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=True)

valid_loader = DataLoader(dataset=validation_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=False)

Теперь давайте визуализируем несколько образцов изображений из набора данных.

Модель LeNet-5

Теперь, когда у нас есть набор данных и все необходимые компоненты, давайте построим модель LeNet-5. Для этой задачи модель LeNet-5 была модифицирована для достижения современных характеристик (SOTA). Он состоит из двух сверточных слоев, одного слоя MaxPool, трех полносвязных (FC) слоев и функций активации ReLU. Вероятности выходного класса получаются с использованием функции log softmax. Входное изображение преобразуется в 28x28.

# LeNet Model
class LeNet5(nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        self.out1, self.in1 = None, None 
        self.out2, self.in2 = None, None 

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        
        self.classifier = nn.Sequential(
            nn.Linear(in_features=400, out_features=120),
            nn.ReLU(),
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            nn.Linear(in_features=84, out_features=n_classes),
        )


    def forward(self, x):
        x, self.in1 = self.maxpool(F.relu(self.conv1(x)))
        self.out1 = x
        x, self.in2 = self.maxpool(F.relu(self.conv2(x)))
        self.out2 = x
        x = torch.flatten(x, 1) #Flatten
        logits = self.classifier(x)
        probs = F.softmax(output, dim=1)
        return logits, probs

Контрольные точки, помеченные как «out1», «in1», «out2» и «in2», используются в качестве крючков для доступа к активациям двух конкретных слоев в модели LeNet-5. Эти контрольные точки позволяют нам захватывать и анализировать промежуточные результаты этих слоев во время прямого прохода модели.

Определить функцию потерь и оптимизатор

model = LeNet5(N_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

Обучение и проверка

# Training Loop
def training_loop(model, trainloader):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  
  for X, y_actual in trainloader:
    X = X.to(DEVICE)
    y_actual = y_actual.to(DEVICE)
    optimizer.zero_grad()
    y_pred, y_probs = model(X)
    _, predicted_labels = torch.max(y_probs,1)
    loss = loss_fn(y_pred, y_actual)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    total += y_actual.size(0)
    correct += (predicted_labels == y_actual).sum()

  train_loss = running_loss / len(train_loader)
  train_accuracy = 100.0 * correct / total

  print('Training Loss: {:.4f}, Accuracy: {:.2f}%'.format(train_loss, train_accuracy))

  return model, train_accuracy, train_loss

# Validation Loop
def validation_loop(model, test_loader):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X, y_actual in test_loader:
          X, y_actual = X.to(DEVICE), y_actual.to(DEVICE)
          y_pred, y_probs = model(X)
          loss = loss_fn(y_pred, y_actual)
          _, predicted_labels = torch.max(y_probs,1)

          running_loss += loss.item()
          total += y_actual.size(0)
          correct += (predicted_labels == y_actual).sum()

        val_loss = running_loss / len(test_loader)
        val_accuracy = 100.0 * correct / total

        print('Validation Loss: {:.4f}, Accuracy: {:.2f}%'.format(val_loss, val_accuracy))

        return model, val_accuracy, val_loss

# Training for 15 Epochs
train_accuracies = []
train_losses = []
val_accuracies = []
val_losses = []

for epoch in range(N_EPOCHS):
    print("----------------")
    print('Epoch:', epoch+1)
    trained_model, train_accuracy, train_loss = training_loop(model, train_loader)
    _, val_accuracy, val_loss =  validation_loop(model, valid_loader)

    train_accuracies.append(train_accuracy)
    train_losses.append(train_loss)
    val_accuracies.append(val_accuracy)
    val_losses.append(val_loss)
----------------
Epoch: 15
Training Loss: 0.0120, Accuracy: 99.60%
Validation Loss: 0.0419, Accuracy: 98.72%

В последнюю эпоху мы видим, что точность обучения составила 99,60%, а потери при обучении — 0,0120. Точность валидации составила 98,72 %, а потеря валидации — 0,0419. Вот графики кривых точности и потерь в зависимости от эпох.

Прогнозы

Вот прогнозы нашей модели —

Преобразование активации:

Далее мы рассмотрим преобразования активации, следуя подходу, изложенному в статье «Визуализация и понимание сверточных сетей» [1]. Этот подход включает в себя сопоставление активаций каждого слоя обратно с пиксельным пространством входного изображения, что позволяет нам визуализировать входные шаблоны, которые способствовали определенным активациям.

Чтобы визуализировать преобразования активации, мы используем деконнет, который присоединен к каждому слою сверточной сети (коннет). Вот обзор процесса:

Входное изображение передается через сеть, в результате чего на каждом слое создаются карты объектов.

  1. Чтобы изучить определенный слой, мы выборочно деактивируем активации всех других слоев и передаем карты объектов соответствующему присоединенному слою деконнета.
  2. В deconvnet мы выполняем распаковку, чтобы определить расположение максимумов в каждой области объединения. Эта информация имеет решающее значение в процессе реконструкции, поскольку помогает восстановить соответствующие местоположения из слоя выше.
  3. После распаковки мы применяем исправление к картам объектов. Поскольку коннет использует выпрямленные линейные единицы (ReLU), что обеспечивает положительные значения на картах признаков, мы еще раз пропускаем реконструированный сигнал через ReLU, чтобы получить достоверные реконструкции положительных признаков.

Этот процесс повторяется итеративно, реконструируя активность в нижнем слое, пока мы не достигнем входного пространства пикселей.

Следуя этому подходу, мы можем эффективно визуализировать входные шаблоны, которые способствовали определенным активациям внутри слоев консети, получая представление об основных представлениях, изученных сетью.

# Function to get transpose of layers
def get_layer_transpose(w):
  layer = torch.transpose(w,0,1)
  layer = torch.transpose(layer,2,3)
  return layer

# Get the weights of each convolutional layer
w2 = model.conv2.weight
w1 = model.conv1.weight

# Get the transposed layers
transposed_kernel2 = get_layer_transpose(w2)
transposed_kernel1 = get_layer_transpose(w1)
# VISUALIZING FIRST CONVOLUTIONAL LAYER
count=1
ROW_IMG = 7
N_ROWS = 7

plt.close('all')

fig = plt.figure()
for index in range(1, N_ROWS):
    plt.subplot(N_ROWS, ROW_IMG, count)
    plt.axis('off')
    plt.imshow(validation_dataset.data[index], cmap='gray')
    count += 1
    
    with torch.no_grad():
        model.eval()
        _, probs = model(validation_dataset[index][0].unsqueeze(0).to(DEVICE))
        for i in range(model.out1.shape[1]):   
          out1 = model.out1.clone()
          
          for j in range(model.out1.shape[1]):
            if i != j:
              out1[0,j] = torch.zeros_like(model.out1[0,0])
          
          out = F.relu(unpool(out1, model.in1))
          tx_kernel1 = get_layer_transpose(w1)
          out = F.conv2d(out, tx_kernel1, stride=1, padding=1)
          tensor = out.detach()
          
          plt.subplot(N_ROWS, ROW_IMG, count)
          plt.axis('off')
          plt.imshow(T.ToPILImage()(tensor[0]), cmap='gray')
          count += 1
    
fig.suptitle('Layer 1 Activation Transform');
plt.show()

# VISUALIZING SECOND CONVOLUTIONAL LAYER
ROW_IMG = 10
N_ROWS = 20
count = 1

fig = plt.figure()
for index in range(1,9):
    plt.subplot(9, model.out2.shape[1]+1, count)
    plt.imshow(validation_dataset.data[index], cmap='gray')
    
    plt.axis('off')
    count += 1
    
    with torch.no_grad():
        model.eval()
        _, probs = model(validation_dataset[index][0].unsqueeze(0).to(DEVICE))
        for i in range(model.out2.shape[1]):
          out2 = model.out2.clone()

          for j in range(model.out2.shape[1]):
            if i != j:
              out2[0,j] = torch.zeros_like(model.out2[0,0])
          out2 = F.relu(unpool(out2, model.in2))
          
          tx_kernel2 = get_layer_transpose(w2)
          tx_kernel1 = get_layer_transpose(w1)
          
          out2 = F.conv2d(out2, tx_kernel2, stride=1, padding=4)
          out1 = F.relu(unpool(out2, model.in1))
          out = F.conv2d(out1, tx_kernel1, stride=1, padding=9)
          tensor = out.detach()
          
          plt.subplot(9, model.out2.shape[1]+1, count)
          plt.axis('off')
          
          plt.imshow(T.ToPILImage()(tensor[0,0]), cmap='gray')
          count += 1
  
fig.suptitle('Layer 2 Activation Transform')
plt.show()

Заключение:

В целом, преобразования активации и деконволюция предоставляют мощный набор инструментов для исследования и анализа внутренней работы CNN, проливая свет на сложные шаблоны и функции, которые обеспечивают их впечатляющую производительность в таких задачах, как классификация изображений, обнаружение объектов и семантическая сегментация.

Использованная литература:

  1. Мэтью Д. Зейлер, Роб Фергус: «Визуализация и понимание сверточных сетей», 2013 г.; архив: 1311.2901.
  2. Ле Кун и др.: «Сравнение алгоритмов обучения распознаванию рукописных цифр», 1995; ICANN