学習目標
この章を読むことで、以下を習得できます:
- ✅ Siamese Networksによるペア学習とContrastive Lossを理解できる
- ✅ Prototypical Networksのプロトタイプベース分類を実装できる
- ✅ Matching NetworksのAttention機構を活用した分類を実装できる
- ✅ Relation Networksの学習可能な距離メトリックを理解できる
- ✅ Few-Shot学習手法の比較実験を設計・実施できる
1. Siamese Networks
1.1 ペア学習の原理
Siamese Networks(シャムネットワーク)は、2つの入力を同じネットワーク(重み共有)で処理し、その類似度を学習するアーキテクチャです。Few-Shot学習において、サンプル間の関係性を直接学習する基本的な手法です。
主要な特徴:
- 重み共有: 2つの入力に同じネットワークを適用することで、一貫した特徴空間を学習
- ペア単位学習: 2つのサンプルが同じクラスか異なるクラスかを直接学習
- 計量学習: 意味的に類似したサンプルは近く、異なるサンプルは遠く配置
1.2 Contrastive Loss
Contrastive Lossは、同じクラスのペアは近く、異なるクラスのペアは遠くなるように学習する損失関数です。
数式定義:
$$ \mathcal{L}(x_1, x_2, y) = y \cdot d(x_1, x_2)^2 + (1-y) \cdot \max(0, m - d(x_1, x_2))^2 $$ここで:
- $d(x_1, x_2)$ はユークリッド距離
- $y \in \{0, 1\}$ はラベル(1=同じクラス、0=異なるクラス)
- $m$ はマージン(異なるクラス間の最小距離)
1.3 画像ペアでの類似度学習
import torch
import torch.nn as nn
import torch.nn.functional as F
class SiameseNetwork(nn.Module):
"""Siamese Network実装"""
def __init__(self, input_channels=3, embedding_dim=128):
super(SiameseNetwork, self).__init__()
# 共有される特徴抽出器
self.encoder = nn.Sequential(
# Conv Block 1
nn.Conv2d(input_channels, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Conv Block 2
nn.Conv2d(64, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Conv Block 3
nn.Conv2d(128, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Flatten(),
)
# 全結合層で埋め込み空間へ
self.fc = nn.Sequential(
nn.Linear(256 * 8 * 8, 512),
nn.ReLU(inplace=True),
nn.Linear(512, embedding_dim)
)
def forward_one(self, x):
"""1つの入力を埋め込み空間へ変換"""
x = self.encoder(x)
x = self.fc(x)
return F.normalize(x, p=2, dim=1) # L2正規化
def forward(self, x1, x2):
"""ペア入力を処理"""
emb1 = self.forward_one(x1)
emb2 = self.forward_one(x2)
return emb1, emb2
class ContrastiveLoss(nn.Module):
"""Contrastive Loss実装"""
def __init__(self, margin=1.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, emb1, emb2, label):
"""
Args:
emb1, emb2: 埋め込みベクトル (batch_size, embedding_dim)
label: ラベル (1=同じクラス, 0=異なるクラス)
"""
# ユークリッド距離
distance = F.pairwise_distance(emb1, emb2, p=2)
# Contrastive Loss
loss_positive = label * torch.pow(distance, 2)
loss_negative = (1 - label) * torch.pow(
torch.clamp(self.margin - distance, min=0.0), 2
)
loss = torch.mean(loss_positive + loss_negative)
return loss
# 学習例
def train_siamese(model, train_loader, num_epochs=10):
"""Siamese Networkの学習"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = ContrastiveLoss(margin=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
model.train()
total_loss = 0
for batch_idx, (img1, img2, labels) in enumerate(train_loader):
img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
# Forward
emb1, emb2 = model(img1, img2)
loss = criterion(emb1, emb2, labels.float())
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
# 使用例
model = SiameseNetwork(input_channels=3, embedding_dim=128)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
2. Prototypical Networks
2.1 プロトタイプ(クラス中心)の計算
Prototypical Networksは、各クラスの「プロトタイプ」(代表的な埋め込みベクトル)を計算し、新しいサンプルを最も近いプロトタイプのクラスに分類します。
平均] E2 --> PA E3 --> PB[プロトタイプB
平均] E4 --> PB Q[Query] --> EQ[エンコーダ] EQ --> D[距離計算] PA --> D PB --> D D --> C[分類] style PA fill:#9d4edd style PB fill:#9d4edd style D fill:#3182ce
プロトタイプの定義:
$$ c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i) $$ここで:
- $c_k$ はクラス$k$のプロトタイプ
- $S_k$ はクラス$k$のサポートセット
- $f_\theta$ はエンコーダネットワーク
2.2 ユークリッド距離ベースの分類
クエリサンプル$x$のクラス確率はsoftmaxで計算されます:
$$ P(y=k|x) = \frac{\exp(-d(f_\theta(x), c_k))}{\sum_{k'} \exp(-d(f_\theta(x), c_{k'}))} $$2.3 PyTorch実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class PrototypicalNetwork(nn.Module):
"""Prototypical Network実装"""
def __init__(self, input_channels=3, hidden_dim=64):
super(PrototypicalNetwork, self).__init__()
# 特徴抽出器(4層CNNブロック)
def conv_block(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
self.encoder = nn.Sequential(
conv_block(input_channels, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, hidden_dim),
nn.Flatten()
)
def forward(self, support_images, support_labels, query_images, n_way, k_shot):
"""
Args:
support_images: (n_way * k_shot, C, H, W)
support_labels: (n_way * k_shot,)
query_images: (n_query, C, H, W)
n_way: クラス数
k_shot: クラスあたりのサンプル数
"""
# サポートセットとクエリセットの埋め込み
support_embeddings = self.encoder(support_images)
query_embeddings = self.encoder(query_images)
# プロトタイプの計算(各クラスの平均)
prototypes = self.compute_prototypes(
support_embeddings, support_labels, n_way
)
# クエリとプロトタイプ間の距離を計算
distances = self.euclidean_distance(query_embeddings, prototypes)
# 負の距離をlogitsとして使用
logits = -distances
return logits
def compute_prototypes(self, embeddings, labels, n_way):
"""各クラスのプロトタイプを計算"""
prototypes = torch.zeros(n_way, embeddings.size(1), device=embeddings.device)
for k in range(n_way):
# クラスkに属するサンプルのマスク
mask = (labels == k)
# クラスkのサンプルの平均を計算
prototypes[k] = embeddings[mask].mean(dim=0)
return prototypes
def euclidean_distance(self, x, y):
"""
ユークリッド距離の計算
Args:
x: (n_query, d)
y: (n_way, d)
Returns:
distances: (n_query, n_way)
"""
n = x.size(0)
m = y.size(0)
d = x.size(1)
# ブロードキャストで効率的に計算
x = x.unsqueeze(1).expand(n, m, d) # (n, m, d)
y = y.unsqueeze(0).expand(n, m, d) # (n, m, d)
return torch.pow(x - y, 2).sum(2) # (n, m)
def train_prototypical(model, train_loader, num_epochs=100, n_way=5, k_shot=1):
"""Prototypical Networkの学習"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
model.train()
total_loss = 0
total_acc = 0
for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
support_imgs = support_imgs.to(device)
support_labels = support_labels.to(device)
query_imgs = query_imgs.to(device)
query_labels = query_labels.to(device)
# Forward
logits = model(support_imgs, support_labels, query_imgs, n_way, k_shot)
loss = criterion(logits, query_labels)
# 精度計算
pred = logits.argmax(dim=1)
acc = (pred == query_labels).float().mean()
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_acc += acc.item()
avg_loss = total_loss / len(train_loader)
avg_acc = total_acc / len(train_loader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")
# 使用例
model = PrototypicalNetwork(input_channels=3, hidden_dim=64)
print(f"Model architecture:\n{model}")
3. Matching Networks
3.1 Attention機構の活用
Matching Networksは、クエリサンプルとサポートセットの各サンプル間でAttention機構を使用し、加重平均でクラス確率を計算します。これにより、サポートセット全体のコンテキストを考慮した分類が可能になります。
重み1] EQ --> A2[Attention
重み2] EQ --> A3[Attention
重み3] ES1 --> A1 ES2 --> A2 ES3 --> A3 A1 --> W[加重平均] A2 --> W A3 --> W W --> P[予測] style EQ fill:#9d4edd style W fill:#3182ce
3.2 Full Context Embeddings
Matching Networksの重要な特徴は、サポートセット全体のコンテキストを考慮した埋め込みを生成することです。これはLSTMなどの系列モデルで実現されます。
Attention重みの計算:
$$ a(x, x_i) = \frac{\exp(c(\hat{x}, \hat{x}_i))}{\sum_j \exp(c(\hat{x}, \hat{x}_j))} $$予測分布:
$$ P(y|x, S) = \sum_{i=1}^k a(x, x_i) y_i $$3.3 実装と評価
import torch
import torch.nn as nn
import torch.nn.functional as F
class MatchingNetwork(nn.Module):
"""Matching Network実装"""
def __init__(self, input_channels=3, hidden_dim=64, lstm_layers=1):
super(MatchingNetwork, self).__init__()
# 特徴抽出器(CNNエンコーダ)
self.encoder = nn.Sequential(
self._conv_block(input_channels, hidden_dim),
self._conv_block(hidden_dim, hidden_dim),
self._conv_block(hidden_dim, hidden_dim),
self._conv_block(hidden_dim, hidden_dim),
)
# 埋め込み次元を計算
self.embedding_dim = hidden_dim * 5 * 5
# Full Context Embeddings用のLSTM
self.lstm = nn.LSTM(
input_size=self.embedding_dim,
hidden_size=self.embedding_dim,
num_layers=lstm_layers,
bidirectional=True,
batch_first=True
)
# 双方向LSTMの出力を元の次元に変換
self.fc = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
def _conv_block(self, in_channels, out_channels):
"""CNNブロック"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
def encode(self, x):
"""画像を埋め込みベクトルに変換"""
batch_size = x.size(0)
x = self.encoder(x)
x = x.view(batch_size, -1)
return x
def full_context_embeddings(self, embeddings):
"""
LSTMでサポートセット全体のコンテキストを考慮
Args:
embeddings: (batch_size, seq_len, embedding_dim)
"""
output, _ = self.lstm(embeddings)
output = self.fc(output)
return output
def attention(self, query_emb, support_emb):
"""
Attention重みを計算
Args:
query_emb: (n_query, embedding_dim)
support_emb: (n_support, embedding_dim)
Returns:
attention_weights: (n_query, n_support)
"""
# コサイン類似度を計算
query_norm = F.normalize(query_emb, p=2, dim=1)
support_norm = F.normalize(support_emb, p=2, dim=1)
similarities = torch.mm(query_norm, support_norm.t())
# Softmaxでattention重みに変換
attention_weights = F.softmax(similarities, dim=1)
return attention_weights
def forward(self, support_images, support_labels, query_images, n_way):
"""
Args:
support_images: (n_way * k_shot, C, H, W)
support_labels: (n_way * k_shot,) one-hot encoded
query_images: (n_query, C, H, W)
"""
# 埋め込みを計算
support_emb = self.encode(support_images) # (n_support, emb_dim)
query_emb = self.encode(query_images) # (n_query, emb_dim)
# Full Context Embeddings(サポートセットのみ)
support_emb_context = self.full_context_embeddings(
support_emb.unsqueeze(0) # (1, n_support, emb_dim)
).squeeze(0) # (n_support, emb_dim)
# Attention重みを計算
attention_weights = self.attention(query_emb, support_emb_context)
# One-hotラベルに変換
support_labels_one_hot = F.one_hot(support_labels, n_way).float()
# Attention重み付き予測
predictions = torch.mm(attention_weights, support_labels_one_hot)
return predictions
# 学習関数
def train_matching(model, train_loader, num_epochs=100, n_way=5):
"""Matching Networkの学習"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
model.train()
total_loss = 0
total_acc = 0
for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
support_imgs = support_imgs.to(device)
support_labels = support_labels.to(device)
query_imgs = query_imgs.to(device)
query_labels = query_labels.to(device)
# Forward
predictions = model(support_imgs, support_labels, query_imgs, n_way)
loss = criterion(predictions, query_labels)
# 精度計算
pred = predictions.argmax(dim=1)
acc = (pred == query_labels).float().mean()
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_acc += acc.item()
avg_loss = total_loss / len(train_loader)
avg_acc = total_acc / len(train_loader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")
# 使用例
model = MatchingNetwork(input_channels=3, hidden_dim=64)
4. Relation Networks
4.1 学習可能な距離メトリック
Relation Networksは、固定的なユークリッド距離やコサイン類似度の代わりに、学習可能なニューラルネットワークで類似度を計算します。これにより、タスク固有の最適な距離関数を学習できます。
Concatenation] EQ --> C C --> R[関係モジュール
CNN] R --> SC[類似度スコア] style ES fill:#9d4edd style EQ fill:#9d4edd style R fill:#3182ce
関係スコアの計算:
$$ r_{i,j} = g_\phi(\text{concat}(f_\theta(x_i), f_\theta(x_j))) $$ここで:
- $f_\theta$ は特徴抽出器
- $g_\phi$ は関係モジュール(学習可能なCNN)
- $r_{i,j}$ はサンプル$i$と$j$の類似度スコア
4.2 CNNベースの関係モジュール
関係モジュールは、結合された特徴ベクトルから類似度スコアを出力する畳み込みネットワークです。
4.3 実装例
import torch
import torch.nn as nn
import torch.nn.functional as F
class RelationNetwork(nn.Module):
"""Relation Network実装"""
def __init__(self, input_channels=3, feature_dim=64):
super(RelationNetwork, self).__init__()
# 特徴抽出器(エンコーダ)
self.encoder = nn.Sequential(
self._conv_block(input_channels, feature_dim),
self._conv_block(feature_dim, feature_dim),
self._conv_block(feature_dim, feature_dim),
self._conv_block(feature_dim, feature_dim),
)
# 関係モジュール(結合された特徴から類似度を計算)
self.relation_module = nn.Sequential(
self._conv_block(feature_dim * 2, feature_dim),
self._conv_block(feature_dim, feature_dim),
nn.Flatten(),
nn.Linear(feature_dim * 5 * 5, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 1),
nn.Sigmoid() # 類似度スコアを[0, 1]に正規化
)
def _conv_block(self, in_channels, out_channels):
"""CNNブロック"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
def forward(self, support_images, query_images, n_way, k_shot):
"""
Args:
support_images: (n_way * k_shot, C, H, W)
query_images: (n_query, C, H, W)
"""
# 特徴抽出
support_features = self.encoder(support_images) # (n_support, D, H, W)
query_features = self.encoder(query_images) # (n_query, D, H, W)
n_support = support_features.size(0)
n_query = query_features.size(0)
D, H, W = support_features.size(1), support_features.size(2), support_features.size(3)
# サポートセットのプロトタイプを計算(各クラスの平均)
support_features_proto = support_features.view(n_way, k_shot, D, H, W).mean(dim=1)
# クエリとプロトタイプのペアを作成
# クエリ特徴を拡張: (n_query, n_way, D, H, W)
query_features_ext = query_features.unsqueeze(1).repeat(1, n_way, 1, 1, 1)
# プロトタイプ特徴を拡張: (n_query, n_way, D, H, W)
support_features_ext = support_features_proto.unsqueeze(0).repeat(n_query, 1, 1, 1, 1)
# 特徴を結合
relation_pairs = torch.cat([query_features_ext, support_features_ext], dim=2)
relation_pairs = relation_pairs.view(-1, D * 2, H, W)
# 関係スコアを計算
relation_scores = self.relation_module(relation_pairs).view(n_query, n_way)
return relation_scores
class MSELoss4RelationNetwork(nn.Module):
"""Relation Network用のMSE Loss"""
def forward(self, relation_scores, labels, n_way):
"""
Args:
relation_scores: (n_query, n_way)
labels: (n_query,)
"""
# One-hotラベルを作成
one_hot_labels = F.one_hot(labels, n_way).float()
# MSE Loss
loss = F.mse_loss(relation_scores, one_hot_labels)
return loss
# 学習関数
def train_relation(model, train_loader, num_epochs=100, n_way=5, k_shot=1):
"""Relation Networkの学習"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = MSELoss4RelationNetwork()
for epoch in range(num_epochs):
model.train()
total_loss = 0
total_acc = 0
for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
support_imgs = support_imgs.to(device)
query_imgs = query_imgs.to(device)
query_labels = query_labels.to(device)
# Forward
relation_scores = model(support_imgs, query_imgs, n_way, k_shot)
loss = criterion(relation_scores, query_labels, n_way)
# 精度計算
pred = relation_scores.argmax(dim=1)
acc = (pred == query_labels).float().mean()
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_acc += acc.item()
avg_loss = total_loss / len(train_loader)
avg_acc = total_acc / len(train_loader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")
# 使用例
model = RelationNetwork(input_channels=3, feature_dim=64)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
5. 実践:手法の比較実験
5.1 miniImageNetデータセット
miniImageNetは、ImageNetのサブセットで、Few-Shot学習のベンチマークとして広く使用されています。
データセット構成:
| 分割 | クラス数 | クラスあたりサンプル数 | 用途 |
|---|---|---|---|
| Train | 64 | 600 | メタ学習の訓練 |
| Validation | 16 | 600 | ハイパーパラメータ調整 |
| Test | 20 | 600 | 最終性能評価 |
5.2 5-way 1-shot/5-shot評価
import torch
import numpy as np
from torch.utils.data import DataLoader
def evaluate_few_shot(model, test_loader, n_way=5, k_shot=1, n_query=15, n_episodes=600):
"""
Few-Shot学習モデルの評価
Args:
model: 評価するモデル(Prototypical, Matching, Relationのいずれか)
test_loader: テストデータローダー
n_way: クラス数
k_shot: サポートセットのサンプル数
n_query: クエリセットのサンプル数
n_episodes: 評価エピソード数
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
accuracies = []
with torch.no_grad():
for episode in range(n_episodes):
# エピソードデータをサンプリング
support_imgs, support_labels, query_imgs, query_labels = next(iter(test_loader))
support_imgs = support_imgs.to(device)
support_labels = support_labels.to(device)
query_imgs = query_imgs.to(device)
query_labels = query_labels.to(device)
# モデルによって異なる推論方法
if hasattr(model, 'relation_module'): # Relation Network
predictions = model(support_imgs, query_imgs, n_way, k_shot)
pred_labels = predictions.argmax(dim=1)
else: # Prototypical or Matching Network
logits = model(support_imgs, support_labels, query_imgs, n_way, k_shot)
pred_labels = logits.argmax(dim=1)
# 精度計算
acc = (pred_labels == query_labels).float().mean().item()
accuracies.append(acc)
if (episode + 1) % 100 == 0:
current_avg = np.mean(accuracies)
current_std = np.std(accuracies)
print(f"Episode [{episode+1}/{n_episodes}], "
f"Acc: {current_avg:.4f} ± {1.96 * current_std / np.sqrt(len(accuracies)):.4f}")
# 最終結果(95%信頼区間)
mean_acc = np.mean(accuracies)
std_acc = np.std(accuracies)
confidence_interval = 1.96 * std_acc / np.sqrt(n_episodes)
return mean_acc, confidence_interval
# データローダー設定の例
class FewShotDataLoader:
"""Few-Shot学習用のデータローダー"""
def __init__(self, dataset, n_way=5, k_shot=1, n_query=15):
self.dataset = dataset
self.n_way = n_way
self.k_shot = k_shot
self.n_query = n_query
def sample_episode(self):
"""1エピソード分のデータをサンプリング"""
# n_wayクラスをランダムに選択
classes = np.random.choice(len(self.dataset.classes), self.n_way, replace=False)
support_imgs = []
support_labels = []
query_imgs = []
query_labels = []
for i, cls in enumerate(classes):
# クラスからk_shot + n_queryサンプルを選択
cls_samples = self.dataset.get_samples_by_class(cls)
indices = np.random.choice(len(cls_samples), self.k_shot + self.n_query, replace=False)
# サポートセット
support_imgs.extend([cls_samples[idx] for idx in indices[:self.k_shot]])
support_labels.extend([i] * self.k_shot)
# クエリセット
query_imgs.extend([cls_samples[idx] for idx in indices[self.k_shot:]])
query_labels.extend([i] * self.n_query)
return (torch.stack(support_imgs), torch.tensor(support_labels),
torch.stack(query_imgs), torch.tensor(query_labels))
5.3 精度比較と考察
# 各手法の比較実験
import pandas as pd
import matplotlib.pyplot as plt
def compare_few_shot_methods(test_loader, n_way=5, k_shot_list=[1, 5]):
"""複数のFew-Shot学習手法を比較"""
results = []
for k_shot in k_shot_list:
print(f"\n{'='*50}")
print(f"{n_way}-way {k_shot}-shot evaluation")
print(f"{'='*50}\n")
# Prototypical Network
print("Evaluating Prototypical Network...")
proto_model = PrototypicalNetwork(input_channels=3, hidden_dim=64)
proto_acc, proto_ci = evaluate_few_shot(proto_model, test_loader, n_way, k_shot)
results.append({
'Method': 'Prototypical',
'Setting': f'{n_way}-way {k_shot}-shot',
'Accuracy': proto_acc,
'CI': proto_ci
})
print(f"Prototypical Network: {proto_acc:.4f} ± {proto_ci:.4f}\n")
# Matching Network
print("Evaluating Matching Network...")
match_model = MatchingNetwork(input_channels=3, hidden_dim=64)
match_acc, match_ci = evaluate_few_shot(match_model, test_loader, n_way, k_shot)
results.append({
'Method': 'Matching',
'Setting': f'{n_way}-way {k_shot}-shot',
'Accuracy': match_acc,
'CI': match_ci
})
print(f"Matching Network: {match_acc:.4f} ± {match_ci:.4f}\n")
# Relation Network
print("Evaluating Relation Network...")
relation_model = RelationNetwork(input_channels=3, feature_dim=64)
relation_acc, relation_ci = evaluate_few_shot(relation_model, test_loader, n_way, k_shot)
results.append({
'Method': 'Relation',
'Setting': f'{n_way}-way {k_shot}-shot',
'Accuracy': relation_acc,
'CI': relation_ci
})
print(f"Relation Network: {relation_acc:.4f} ± {relation_ci:.4f}\n")
return pd.DataFrame(results)
# 結果の可視化
def plot_comparison(results_df):
"""比較結果を可視化"""
fig, ax = plt.subplots(figsize=(10, 6))
# 1-shotと5-shotで分離
settings = results_df['Setting'].unique()
x = np.arange(len(results_df['Method'].unique()))
width = 0.35
for i, setting in enumerate(settings):
data = results_df[results_df['Setting'] == setting]
accuracies = data['Accuracy'].values
cis = data['CI'].values
ax.bar(x + i * width, accuracies, width,
yerr=cis, label=setting, capsize=5)
ax.set_xlabel('Method')
ax.set_ylabel('Accuracy')
ax.set_title('Few-Shot Learning Methods Comparison on miniImageNet')
ax.set_xticks(x + width / 2)
ax.set_xticklabels(results_df['Method'].unique())
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('few_shot_comparison.png', dpi=300, bbox_inches='tight')
plt.show()
# 実行例
# results_df = compare_few_shot_methods(test_loader, n_way=5, k_shot_list=[1, 5])
# plot_comparison(results_df)
典型的な結果(miniImageNet):
| 手法 | 5-way 1-shot | 5-way 5-shot | 主な特徴 |
|---|---|---|---|
| Prototypical Networks | 49.42% ± 0.78% | 68.20% ± 0.66% | シンプルで効率的 |
| Matching Networks | 46.60% ± 0.78% | 60.00% ± 0.71% | Attention機構 |
| Relation Networks | 50.44% ± 0.82% | 65.32% ± 0.70% | 学習可能な距離 |
5.4 考察と手法選択のガイドライン
Prototypical Networks
- 長所: 実装がシンプル、計算効率が良い、多くのタスクで高精度
- 短所: 固定的なユークリッド距離に依存
- 推奨用途: ベースライン手法として、リソースが限られた環境
Matching Networks
- 長所: Attention機構でサポートセット全体を考慮
- 短所: LSTMによる計算コストが高い
- 推奨用途: サポートセット間の関係性が重要なタスク
Relation Networks
- 長所: タスク固有の最適な距離関数を学習可能
- 短所: パラメータ数が多く、学習に時間がかかる
- 推奨用途: 複雑な類似度が必要なタスク、十分なデータがある場合
実践的なアドバイス: 新しいタスクでは、まずPrototypical Networksをベースラインとして試し、性能が不十分な場合にRelation Networksを検討するのが効率的です。
演習問題
演習1:Siamese Networkの改良
提供されたSiamese NetworkにTriplet Lossを実装し、Contrastive Lossとの性能を比較してください。Triplet Lossは、アンカー、ポジティブ、ネガティブの3つのサンプルを使用します。
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
# TODO: Triplet Lossを実装
# ヒント: L = max(0, d(a,p) - d(a,n) + margin)
pass
演習2:Prototypical Networksの拡張
Prototypical Networksを拡張し、各クラスのプロトタイプを単純な平均ではなく、Attention機構を使った加重平均で計算してください。これにより、ノイズの多いサンプルの影響を減らせます。
def compute_prototypes_with_attention(self, embeddings, labels, n_way):
"""Attention機構を使ったプロトタイプ計算"""
# TODO: 実装
# ヒント: サンプル間の類似度に基づいてattention重みを計算
pass
演習3:マルチモーダルFew-Shot学習
画像とテキストの両方を入力とするマルチモーダルPrototypical Networkを設計してください。画像にはCNN、テキストにはTransformerを使用し、両方の埋め込みを結合します。
class MultimodalPrototypicalNetwork(nn.Module):
def __init__(self):
super().__init__()
# TODO: 画像エンコーダとテキストエンコーダを定義
pass
def forward(self, images, texts):
# TODO: マルチモーダル埋め込みを計算
pass
演習4:Few-Shot学習の実アプリケーション
医療画像診断のシナリオを想定し、限られた症例画像(各疾患5枚程度)から新しい疾患を分類するシステムを設計してください。どの手法が最適か、その理由とともに説明してください。また、データ拡張やドメイン適応の戦略も考えてください。
まとめ
この章では、Few-Shot学習の主要な手法について学びました:
- Siamese Networks: ペア学習とContrastive Lossによる類似度学習の基礎
- Prototypical Networks: プロトタイプベースのシンプルで効果的な分類
- Matching Networks: Attention機構によるコンテキスト考慮型分類
- Relation Networks: 学習可能な距離メトリックによる柔軟な類似度計算
これらの手法は、それぞれ異なる強みを持ち、タスクやリソースに応じて選択できます。次章では、これらの手法をより高度な最適化アルゴリズム(MAML等)と組み合わせる方法を学びます。