Понимание и визуализация промежуточных слоев CNN
Мы все знаем, что сверточные нейронные сети (CNN) широко известны и используются для задач классификации изображений, причины их эффективности и конкретные шаблоны, изученные каждым уровнем CNN, остаются несколько загадочными. В этой статье я попытаюсь изучить визуализацию выходных данных каждого сверточного слоя, чтобы получить представление о наблюдаемых закономерностях и пролить свет на замечательную производительность CNN. После обучения модели мы визуализируем, как выглядят промежуточные активации. Эти шаги можно рассматривать как начальные этапы в области объяснимого ИИ и интерпретируемости моделей.
В этой конкретной задаче популярный набор данных MNIST и классическая модель «LeNet5» используются для классификации рукописных цифр MNIST.
- Вы можете скачать набор данных с официального сайта -(http://yann.lecun.com/exdb/mnist/) или
- напрямую использовать набор данных MNIST, доступный в PyTorch, используя (https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html).
Примечание. В этом задании используется исходная форма 28 x 28.
Сейчас я продемонстрирую код, объясняя его.
Импортируйте все необходимые библиотеки
# Import all the necessary libraries import numpy as np from datetime import datetime import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader import torchvision.transforms as T from torchvision import datasets, transforms import matplotlib.pyplot as plt # check device DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' print(DEVICE)
Определить гиперпараметры
# parameters RANDOM_SEED = 42 LEARNING_RATE = 0.001 BATCH_SIZE = 128 N_EPOCHS = 15 IMG_SIZE = 28 N_CLASSES = 10 # Unpool unpool = nn.MaxUnpool2d(kernel_size=2)
Получение наборов данных для обучения и проверки с помощью загрузчиков данных
# Transforms transforms_mnist = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()]) # Dowload and create Train and Validation Datasets train_dataset = datasets.MNIST(root='data', train=True, transform=transforms_mnist, download=True) validation_dataset = datasets.MNIST(root='data', train=False, transform=transforms_mnist) # Create Data Loaders train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True) valid_loader = DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
Теперь давайте визуализируем несколько образцов изображений из набора данных.
Модель LeNet-5
Теперь, когда у нас есть набор данных и все необходимые компоненты, давайте построим модель LeNet-5. Для этой задачи модель LeNet-5 была модифицирована для достижения современных характеристик (SOTA). Он состоит из двух сверточных слоев, одного слоя MaxPool, трех полносвязных (FC) слоев и функций активации ReLU. Вероятности выходного класса получаются с использованием функции log softmax. Входное изображение преобразуется в 28x28.
# LeNet Model class LeNet5(nn.Module): def __init__(self, n_classes): super(LeNet5, self).__init__() self.out1, self.in1 = None, None self.out2, self.in2 = None, None self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2) self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1) self.classifier = nn.Sequential( nn.Linear(in_features=400, out_features=120), nn.ReLU(), nn.Linear(in_features=120, out_features=84), nn.ReLU(), nn.Linear(in_features=84, out_features=n_classes), ) def forward(self, x): x, self.in1 = self.maxpool(F.relu(self.conv1(x))) self.out1 = x x, self.in2 = self.maxpool(F.relu(self.conv2(x))) self.out2 = x x = torch.flatten(x, 1) #Flatten logits = self.classifier(x) probs = F.softmax(output, dim=1) return logits, probs
Контрольные точки, помеченные как «out1», «in1», «out2» и «in2», используются в качестве крючков для доступа к активациям двух конкретных слоев в модели LeNet-5. Эти контрольные точки позволяют нам захватывать и анализировать промежуточные результаты этих слоев во время прямого прохода модели.
Определить функцию потерь и оптимизатор
model = LeNet5(N_CLASSES).to(DEVICE) optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) loss_fn = nn.CrossEntropyLoss()
Обучение и проверка
# Training Loop def training_loop(model, trainloader): model.train() running_loss = 0.0 correct = 0 total = 0 for X, y_actual in trainloader: X = X.to(DEVICE) y_actual = y_actual.to(DEVICE) optimizer.zero_grad() y_pred, y_probs = model(X) _, predicted_labels = torch.max(y_probs,1) loss = loss_fn(y_pred, y_actual) loss.backward() optimizer.step() running_loss += loss.item() total += y_actual.size(0) correct += (predicted_labels == y_actual).sum() train_loss = running_loss / len(train_loader) train_accuracy = 100.0 * correct / total print('Training Loss: {:.4f}, Accuracy: {:.2f}%'.format(train_loss, train_accuracy)) return model, train_accuracy, train_loss # Validation Loop def validation_loop(model, test_loader): model.eval() running_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for X, y_actual in test_loader: X, y_actual = X.to(DEVICE), y_actual.to(DEVICE) y_pred, y_probs = model(X) loss = loss_fn(y_pred, y_actual) _, predicted_labels = torch.max(y_probs,1) running_loss += loss.item() total += y_actual.size(0) correct += (predicted_labels == y_actual).sum() val_loss = running_loss / len(test_loader) val_accuracy = 100.0 * correct / total print('Validation Loss: {:.4f}, Accuracy: {:.2f}%'.format(val_loss, val_accuracy)) return model, val_accuracy, val_loss # Training for 15 Epochs train_accuracies = [] train_losses = [] val_accuracies = [] val_losses = [] for epoch in range(N_EPOCHS): print("----------------") print('Epoch:', epoch+1) trained_model, train_accuracy, train_loss = training_loop(model, train_loader) _, val_accuracy, val_loss = validation_loop(model, valid_loader) train_accuracies.append(train_accuracy) train_losses.append(train_loss) val_accuracies.append(val_accuracy) val_losses.append(val_loss) ---------------- Epoch: 15 Training Loss: 0.0120, Accuracy: 99.60% Validation Loss: 0.0419, Accuracy: 98.72%
В последнюю эпоху мы видим, что точность обучения составила 99,60%, а потери при обучении — 0,0120. Точность валидации составила 98,72 %, а потеря валидации — 0,0419. Вот графики кривых точности и потерь в зависимости от эпох.
Прогнозы
Вот прогнозы нашей модели —
Преобразование активации:
Далее мы рассмотрим преобразования активации, следуя подходу, изложенному в статье «Визуализация и понимание сверточных сетей» [1]. Этот подход включает в себя сопоставление активаций каждого слоя обратно с пиксельным пространством входного изображения, что позволяет нам визуализировать входные шаблоны, которые способствовали определенным активациям.
Чтобы визуализировать преобразования активации, мы используем деконнет, который присоединен к каждому слою сверточной сети (коннет). Вот обзор процесса:
Входное изображение передается через сеть, в результате чего на каждом слое создаются карты объектов.
- Чтобы изучить определенный слой, мы выборочно деактивируем активации всех других слоев и передаем карты объектов соответствующему присоединенному слою деконнета.
- В deconvnet мы выполняем распаковку, чтобы определить расположение максимумов в каждой области объединения. Эта информация имеет решающее значение в процессе реконструкции, поскольку помогает восстановить соответствующие местоположения из слоя выше.
- После распаковки мы применяем исправление к картам объектов. Поскольку коннет использует выпрямленные линейные единицы (ReLU), что обеспечивает положительные значения на картах признаков, мы еще раз пропускаем реконструированный сигнал через ReLU, чтобы получить достоверные реконструкции положительных признаков.
Этот процесс повторяется итеративно, реконструируя активность в нижнем слое, пока мы не достигнем входного пространства пикселей.
Следуя этому подходу, мы можем эффективно визуализировать входные шаблоны, которые способствовали определенным активациям внутри слоев консети, получая представление об основных представлениях, изученных сетью.
# Function to get transpose of layers def get_layer_transpose(w): layer = torch.transpose(w,0,1) layer = torch.transpose(layer,2,3) return layer # Get the weights of each convolutional layer w2 = model.conv2.weight w1 = model.conv1.weight # Get the transposed layers transposed_kernel2 = get_layer_transpose(w2) transposed_kernel1 = get_layer_transpose(w1) # VISUALIZING FIRST CONVOLUTIONAL LAYER count=1 ROW_IMG = 7 N_ROWS = 7 plt.close('all') fig = plt.figure() for index in range(1, N_ROWS): plt.subplot(N_ROWS, ROW_IMG, count) plt.axis('off') plt.imshow(validation_dataset.data[index], cmap='gray') count += 1 with torch.no_grad(): model.eval() _, probs = model(validation_dataset[index][0].unsqueeze(0).to(DEVICE)) for i in range(model.out1.shape[1]): out1 = model.out1.clone() for j in range(model.out1.shape[1]): if i != j: out1[0,j] = torch.zeros_like(model.out1[0,0]) out = F.relu(unpool(out1, model.in1)) tx_kernel1 = get_layer_transpose(w1) out = F.conv2d(out, tx_kernel1, stride=1, padding=1) tensor = out.detach() plt.subplot(N_ROWS, ROW_IMG, count) plt.axis('off') plt.imshow(T.ToPILImage()(tensor[0]), cmap='gray') count += 1 fig.suptitle('Layer 1 Activation Transform'); plt.show()
# VISUALIZING SECOND CONVOLUTIONAL LAYER ROW_IMG = 10 N_ROWS = 20 count = 1 fig = plt.figure() for index in range(1,9): plt.subplot(9, model.out2.shape[1]+1, count) plt.imshow(validation_dataset.data[index], cmap='gray') plt.axis('off') count += 1 with torch.no_grad(): model.eval() _, probs = model(validation_dataset[index][0].unsqueeze(0).to(DEVICE)) for i in range(model.out2.shape[1]): out2 = model.out2.clone() for j in range(model.out2.shape[1]): if i != j: out2[0,j] = torch.zeros_like(model.out2[0,0]) out2 = F.relu(unpool(out2, model.in2)) tx_kernel2 = get_layer_transpose(w2) tx_kernel1 = get_layer_transpose(w1) out2 = F.conv2d(out2, tx_kernel2, stride=1, padding=4) out1 = F.relu(unpool(out2, model.in1)) out = F.conv2d(out1, tx_kernel1, stride=1, padding=9) tensor = out.detach() plt.subplot(9, model.out2.shape[1]+1, count) plt.axis('off') plt.imshow(T.ToPILImage()(tensor[0,0]), cmap='gray') count += 1 fig.suptitle('Layer 2 Activation Transform') plt.show()
Заключение:
В целом, преобразования активации и деконволюция предоставляют мощный набор инструментов для исследования и анализа внутренней работы CNN, проливая свет на сложные шаблоны и функции, которые обеспечивают их впечатляющую производительность в таких задачах, как классификация изображений, обнаружение объектов и семантическая сегментация.
Использованная литература:
- Мэтью Д. Зейлер, Роб Фергус: «Визуализация и понимание сверточных сетей», 2013 г.; архив: 1311.2901.
- Ле Кун и др.: «Сравнение алгоритмов обучения распознаванию рукописных цифр», 1995; ICANN