第3章:Few-Shot学習手法

メトリック学習に基づく少数サンプル分類アーキテクチャ

📖 読了時間: 32分 📊 難易度: 中級〜上級 💻 コード例: 8個 📝 演習問題: 4問

学習目標

この章を読むことで、以下を習得できます:

1. Siamese Networks

1.1 ペア学習の原理

Siamese Networks(シャムネットワーク)は、2つの入力を同じネットワーク(重み共有)で処理し、その類似度を学習するアーキテクチャです。Few-Shot学習において、サンプル間の関係性を直接学習する基本的な手法です。

graph LR A[画像1] --> B[CNN] C[画像2] --> D[CNN] B --> E[埋め込み1] D --> F[埋め込み2] E --> G[距離計算] F --> G G --> H[類似度スコア] style B fill:#9d4edd style D fill:#9d4edd style G fill:#3182ce

主要な特徴:

1.2 Contrastive Loss

Contrastive Lossは、同じクラスのペアは近く、異なるクラスのペアは遠くなるように学習する損失関数です。

数式定義:

$$ \mathcal{L}(x_1, x_2, y) = y \cdot d(x_1, x_2)^2 + (1-y) \cdot \max(0, m - d(x_1, x_2))^2 $$

ここで:

1.3 画像ペアでの類似度学習

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """Siamese Network実装"""

    def __init__(self, input_channels=3, embedding_dim=128):
        super(SiameseNetwork, self).__init__()

        # 共有される特徴抽出器
        self.encoder = nn.Sequential(
            # Conv Block 1
            nn.Conv2d(input_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Conv Block 2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # Conv Block 3
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Flatten(),
        )

        # 全結合層で埋め込み空間へ
        self.fc = nn.Sequential(
            nn.Linear(256 * 8 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, embedding_dim)
        )

    def forward_one(self, x):
        """1つの入力を埋め込み空間へ変換"""
        x = self.encoder(x)
        x = self.fc(x)
        return F.normalize(x, p=2, dim=1)  # L2正規化

    def forward(self, x1, x2):
        """ペア入力を処理"""
        emb1 = self.forward_one(x1)
        emb2 = self.forward_one(x2)
        return emb1, emb2

class ContrastiveLoss(nn.Module):
    """Contrastive Loss実装"""

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """
        Args:
            emb1, emb2: 埋め込みベクトル (batch_size, embedding_dim)
            label: ラベル (1=同じクラス, 0=異なるクラス)
        """
        # ユークリッド距離
        distance = F.pairwise_distance(emb1, emb2, p=2)

        # Contrastive Loss
        loss_positive = label * torch.pow(distance, 2)
        loss_negative = (1 - label) * torch.pow(
            torch.clamp(self.margin - distance, min=0.0), 2
        )

        loss = torch.mean(loss_positive + loss_negative)
        return loss

# 学習例
def train_siamese(model, train_loader, num_epochs=10):
    """Siamese Networkの学習"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = ContrastiveLoss(margin=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, (img1, img2, labels) in enumerate(train_loader):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

            # Forward
            emb1, emb2 = model(img1, img2)
            loss = criterion(emb1, emb2, labels.float())

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# 使用例
model = SiameseNetwork(input_channels=3, embedding_dim=128)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

2. Prototypical Networks

2.1 プロトタイプ(クラス中心)の計算

Prototypical Networksは、各クラスの「プロトタイプ」(代表的な埋め込みベクトル)を計算し、新しいサンプルを最も近いプロトタイプのクラスに分類します。

graph TB subgraph Support Set A1[クラスA サンプル1] --> E1[エンコーダ] A2[クラスA サンプル2] --> E2[エンコーダ] B1[クラスB サンプル1] --> E3[エンコーダ] B2[クラスB サンプル2] --> E4[エンコーダ] end E1 --> PA[プロトタイプA
平均] E2 --> PA E3 --> PB[プロトタイプB
平均] E4 --> PB Q[Query] --> EQ[エンコーダ] EQ --> D[距離計算] PA --> D PB --> D D --> C[分類] style PA fill:#9d4edd style PB fill:#9d4edd style D fill:#3182ce

プロトタイプの定義:

$$ c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i) $$

ここで:

2.2 ユークリッド距離ベースの分類

クエリサンプル$x$のクラス確率はsoftmaxで計算されます:

$$ P(y=k|x) = \frac{\exp(-d(f_\theta(x), c_k))}{\sum_{k'} \exp(-d(f_\theta(x), c_{k'}))} $$

2.3 PyTorch実装

import torch
import torch.nn as nn
import torch.nn.functional as F

class PrototypicalNetwork(nn.Module):
    """Prototypical Network実装"""

    def __init__(self, input_channels=3, hidden_dim=64):
        super(PrototypicalNetwork, self).__init__()

        # 特徴抽出器(4層CNNブロック)
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2)
            )

        self.encoder = nn.Sequential(
            conv_block(input_channels, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            nn.Flatten()
        )

    def forward(self, support_images, support_labels, query_images, n_way, k_shot):
        """
        Args:
            support_images: (n_way * k_shot, C, H, W)
            support_labels: (n_way * k_shot,)
            query_images: (n_query, C, H, W)
            n_way: クラス数
            k_shot: クラスあたりのサンプル数
        """
        # サポートセットとクエリセットの埋め込み
        support_embeddings = self.encoder(support_images)
        query_embeddings = self.encoder(query_images)

        # プロトタイプの計算(各クラスの平均)
        prototypes = self.compute_prototypes(
            support_embeddings, support_labels, n_way
        )

        # クエリとプロトタイプ間の距離を計算
        distances = self.euclidean_distance(query_embeddings, prototypes)

        # 負の距離をlogitsとして使用
        logits = -distances
        return logits

    def compute_prototypes(self, embeddings, labels, n_way):
        """各クラスのプロトタイプを計算"""
        prototypes = torch.zeros(n_way, embeddings.size(1), device=embeddings.device)

        for k in range(n_way):
            # クラスkに属するサンプルのマスク
            mask = (labels == k)
            # クラスkのサンプルの平均を計算
            prototypes[k] = embeddings[mask].mean(dim=0)

        return prototypes

    def euclidean_distance(self, x, y):
        """
        ユークリッド距離の計算
        Args:
            x: (n_query, d)
            y: (n_way, d)
        Returns:
            distances: (n_query, n_way)
        """
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)

        # ブロードキャストで効率的に計算
        x = x.unsqueeze(1).expand(n, m, d)  # (n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)  # (n, m, d)

        return torch.pow(x - y, 2).sum(2)  # (n, m)

def train_prototypical(model, train_loader, num_epochs=100, n_way=5, k_shot=1):
    """Prototypical Networkの学習"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_acc = 0

        for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
            support_imgs = support_imgs.to(device)
            support_labels = support_labels.to(device)
            query_imgs = query_imgs.to(device)
            query_labels = query_labels.to(device)

            # Forward
            logits = model(support_imgs, support_labels, query_imgs, n_way, k_shot)
            loss = criterion(logits, query_labels)

            # 精度計算
            pred = logits.argmax(dim=1)
            acc = (pred == query_labels).float().mean()

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc.item()

        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")

# 使用例
model = PrototypicalNetwork(input_channels=3, hidden_dim=64)
print(f"Model architecture:\n{model}")

3. Matching Networks

3.1 Attention機構の活用

Matching Networksは、クエリサンプルとサポートセットの各サンプル間でAttention機構を使用し、加重平均でクラス確率を計算します。これにより、サポートセット全体のコンテキストを考慮した分類が可能になります。

graph TB subgraph Support Set S1[サポート1] --> ES1[埋め込み] S2[サポート2] --> ES2[埋め込み] S3[サポート3] --> ES3[埋め込み] end Q[クエリ] --> EQ[埋め込み + LSTM] EQ --> A1[Attention
重み1] EQ --> A2[Attention
重み2] EQ --> A3[Attention
重み3] ES1 --> A1 ES2 --> A2 ES3 --> A3 A1 --> W[加重平均] A2 --> W A3 --> W W --> P[予測] style EQ fill:#9d4edd style W fill:#3182ce

3.2 Full Context Embeddings

Matching Networksの重要な特徴は、サポートセット全体のコンテキストを考慮した埋め込みを生成することです。これはLSTMなどの系列モデルで実現されます。

Attention重みの計算:

$$ a(x, x_i) = \frac{\exp(c(\hat{x}, \hat{x}_i))}{\sum_j \exp(c(\hat{x}, \hat{x}_j))} $$

予測分布:

$$ P(y|x, S) = \sum_{i=1}^k a(x, x_i) y_i $$

3.3 実装と評価

import torch
import torch.nn as nn
import torch.nn.functional as F

class MatchingNetwork(nn.Module):
    """Matching Network実装"""

    def __init__(self, input_channels=3, hidden_dim=64, lstm_layers=1):
        super(MatchingNetwork, self).__init__()

        # 特徴抽出器(CNNエンコーダ)
        self.encoder = nn.Sequential(
            self._conv_block(input_channels, hidden_dim),
            self._conv_block(hidden_dim, hidden_dim),
            self._conv_block(hidden_dim, hidden_dim),
            self._conv_block(hidden_dim, hidden_dim),
        )

        # 埋め込み次元を計算
        self.embedding_dim = hidden_dim * 5 * 5

        # Full Context Embeddings用のLSTM
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.embedding_dim,
            num_layers=lstm_layers,
            bidirectional=True,
            batch_first=True
        )

        # 双方向LSTMの出力を元の次元に変換
        self.fc = nn.Linear(self.embedding_dim * 2, self.embedding_dim)

    def _conv_block(self, in_channels, out_channels):
        """CNNブロック"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def encode(self, x):
        """画像を埋め込みベクトルに変換"""
        batch_size = x.size(0)
        x = self.encoder(x)
        x = x.view(batch_size, -1)
        return x

    def full_context_embeddings(self, embeddings):
        """
        LSTMでサポートセット全体のコンテキストを考慮
        Args:
            embeddings: (batch_size, seq_len, embedding_dim)
        """
        output, _ = self.lstm(embeddings)
        output = self.fc(output)
        return output

    def attention(self, query_emb, support_emb):
        """
        Attention重みを計算
        Args:
            query_emb: (n_query, embedding_dim)
            support_emb: (n_support, embedding_dim)
        Returns:
            attention_weights: (n_query, n_support)
        """
        # コサイン類似度を計算
        query_norm = F.normalize(query_emb, p=2, dim=1)
        support_norm = F.normalize(support_emb, p=2, dim=1)

        similarities = torch.mm(query_norm, support_norm.t())

        # Softmaxでattention重みに変換
        attention_weights = F.softmax(similarities, dim=1)
        return attention_weights

    def forward(self, support_images, support_labels, query_images, n_way):
        """
        Args:
            support_images: (n_way * k_shot, C, H, W)
            support_labels: (n_way * k_shot,) one-hot encoded
            query_images: (n_query, C, H, W)
        """
        # 埋め込みを計算
        support_emb = self.encode(support_images)  # (n_support, emb_dim)
        query_emb = self.encode(query_images)      # (n_query, emb_dim)

        # Full Context Embeddings(サポートセットのみ)
        support_emb_context = self.full_context_embeddings(
            support_emb.unsqueeze(0)  # (1, n_support, emb_dim)
        ).squeeze(0)  # (n_support, emb_dim)

        # Attention重みを計算
        attention_weights = self.attention(query_emb, support_emb_context)

        # One-hotラベルに変換
        support_labels_one_hot = F.one_hot(support_labels, n_way).float()

        # Attention重み付き予測
        predictions = torch.mm(attention_weights, support_labels_one_hot)

        return predictions

# 学習関数
def train_matching(model, train_loader, num_epochs=100, n_way=5):
    """Matching Networkの学習"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_acc = 0

        for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
            support_imgs = support_imgs.to(device)
            support_labels = support_labels.to(device)
            query_imgs = query_imgs.to(device)
            query_labels = query_labels.to(device)

            # Forward
            predictions = model(support_imgs, support_labels, query_imgs, n_way)
            loss = criterion(predictions, query_labels)

            # 精度計算
            pred = predictions.argmax(dim=1)
            acc = (pred == query_labels).float().mean()

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc.item()

        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")

# 使用例
model = MatchingNetwork(input_channels=3, hidden_dim=64)

4. Relation Networks

4.1 学習可能な距離メトリック

Relation Networksは、固定的なユークリッド距離やコサイン類似度の代わりに、学習可能なニューラルネットワークで類似度を計算します。これにより、タスク固有の最適な距離関数を学習できます。

graph TB S[サポート] --> ES[特徴抽出器] Q[クエリ] --> EQ[特徴抽出器] ES --> C[結合
Concatenation] EQ --> C C --> R[関係モジュール
CNN] R --> SC[類似度スコア] style ES fill:#9d4edd style EQ fill:#9d4edd style R fill:#3182ce

関係スコアの計算:

$$ r_{i,j} = g_\phi(\text{concat}(f_\theta(x_i), f_\theta(x_j))) $$

ここで:

4.2 CNNベースの関係モジュール

関係モジュールは、結合された特徴ベクトルから類似度スコアを出力する畳み込みネットワークです。

4.3 実装例

import torch
import torch.nn as nn
import torch.nn.functional as F

class RelationNetwork(nn.Module):
    """Relation Network実装"""

    def __init__(self, input_channels=3, feature_dim=64):
        super(RelationNetwork, self).__init__()

        # 特徴抽出器(エンコーダ)
        self.encoder = nn.Sequential(
            self._conv_block(input_channels, feature_dim),
            self._conv_block(feature_dim, feature_dim),
            self._conv_block(feature_dim, feature_dim),
            self._conv_block(feature_dim, feature_dim),
        )

        # 関係モジュール(結合された特徴から類似度を計算)
        self.relation_module = nn.Sequential(
            self._conv_block(feature_dim * 2, feature_dim),
            self._conv_block(feature_dim, feature_dim),
            nn.Flatten(),
            nn.Linear(feature_dim * 5 * 5, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 類似度スコアを[0, 1]に正規化
        )

    def _conv_block(self, in_channels, out_channels):
        """CNNブロック"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def forward(self, support_images, query_images, n_way, k_shot):
        """
        Args:
            support_images: (n_way * k_shot, C, H, W)
            query_images: (n_query, C, H, W)
        """
        # 特徴抽出
        support_features = self.encoder(support_images)  # (n_support, D, H, W)
        query_features = self.encoder(query_images)      # (n_query, D, H, W)

        n_support = support_features.size(0)
        n_query = query_features.size(0)
        D, H, W = support_features.size(1), support_features.size(2), support_features.size(3)

        # サポートセットのプロトタイプを計算(各クラスの平均)
        support_features_proto = support_features.view(n_way, k_shot, D, H, W).mean(dim=1)

        # クエリとプロトタイプのペアを作成
        # クエリ特徴を拡張: (n_query, n_way, D, H, W)
        query_features_ext = query_features.unsqueeze(1).repeat(1, n_way, 1, 1, 1)

        # プロトタイプ特徴を拡張: (n_query, n_way, D, H, W)
        support_features_ext = support_features_proto.unsqueeze(0).repeat(n_query, 1, 1, 1, 1)

        # 特徴を結合
        relation_pairs = torch.cat([query_features_ext, support_features_ext], dim=2)
        relation_pairs = relation_pairs.view(-1, D * 2, H, W)

        # 関係スコアを計算
        relation_scores = self.relation_module(relation_pairs).view(n_query, n_way)

        return relation_scores

class MSELoss4RelationNetwork(nn.Module):
    """Relation Network用のMSE Loss"""

    def forward(self, relation_scores, labels, n_way):
        """
        Args:
            relation_scores: (n_query, n_way)
            labels: (n_query,)
        """
        # One-hotラベルを作成
        one_hot_labels = F.one_hot(labels, n_way).float()

        # MSE Loss
        loss = F.mse_loss(relation_scores, one_hot_labels)
        return loss

# 学習関数
def train_relation(model, train_loader, num_epochs=100, n_way=5, k_shot=1):
    """Relation Networkの学習"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = MSELoss4RelationNetwork()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_acc = 0

        for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in enumerate(train_loader):
            support_imgs = support_imgs.to(device)
            query_imgs = query_imgs.to(device)
            query_labels = query_labels.to(device)

            # Forward
            relation_scores = model(support_imgs, query_imgs, n_way, k_shot)
            loss = criterion(relation_scores, query_labels, n_way)

            # 精度計算
            pred = relation_scores.argmax(dim=1)
            acc = (pred == query_labels).float().mean()

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc.item()

        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")

# 使用例
model = RelationNetwork(input_channels=3, feature_dim=64)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

5. 実践:手法の比較実験

5.1 miniImageNetデータセット

miniImageNetは、ImageNetのサブセットで、Few-Shot学習のベンチマークとして広く使用されています。

データセット構成:

分割 クラス数 クラスあたりサンプル数 用途
Train 64 600 メタ学習の訓練
Validation 16 600 ハイパーパラメータ調整
Test 20 600 最終性能評価

5.2 5-way 1-shot/5-shot評価

import torch
import numpy as np
from torch.utils.data import DataLoader

def evaluate_few_shot(model, test_loader, n_way=5, k_shot=1, n_query=15, n_episodes=600):
    """
    Few-Shot学習モデルの評価

    Args:
        model: 評価するモデル(Prototypical, Matching, Relationのいずれか)
        test_loader: テストデータローダー
        n_way: クラス数
        k_shot: サポートセットのサンプル数
        n_query: クエリセットのサンプル数
        n_episodes: 評価エピソード数
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    accuracies = []

    with torch.no_grad():
        for episode in range(n_episodes):
            # エピソードデータをサンプリング
            support_imgs, support_labels, query_imgs, query_labels = next(iter(test_loader))

            support_imgs = support_imgs.to(device)
            support_labels = support_labels.to(device)
            query_imgs = query_imgs.to(device)
            query_labels = query_labels.to(device)

            # モデルによって異なる推論方法
            if hasattr(model, 'relation_module'):  # Relation Network
                predictions = model(support_imgs, query_imgs, n_way, k_shot)
                pred_labels = predictions.argmax(dim=1)
            else:  # Prototypical or Matching Network
                logits = model(support_imgs, support_labels, query_imgs, n_way, k_shot)
                pred_labels = logits.argmax(dim=1)

            # 精度計算
            acc = (pred_labels == query_labels).float().mean().item()
            accuracies.append(acc)

            if (episode + 1) % 100 == 0:
                current_avg = np.mean(accuracies)
                current_std = np.std(accuracies)
                print(f"Episode [{episode+1}/{n_episodes}], "
                      f"Acc: {current_avg:.4f} ± {1.96 * current_std / np.sqrt(len(accuracies)):.4f}")

    # 最終結果(95%信頼区間)
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    confidence_interval = 1.96 * std_acc / np.sqrt(n_episodes)

    return mean_acc, confidence_interval

# データローダー設定の例
class FewShotDataLoader:
    """Few-Shot学習用のデータローダー"""

    def __init__(self, dataset, n_way=5, k_shot=1, n_query=15):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query

    def sample_episode(self):
        """1エピソード分のデータをサンプリング"""
        # n_wayクラスをランダムに選択
        classes = np.random.choice(len(self.dataset.classes), self.n_way, replace=False)

        support_imgs = []
        support_labels = []
        query_imgs = []
        query_labels = []

        for i, cls in enumerate(classes):
            # クラスからk_shot + n_queryサンプルを選択
            cls_samples = self.dataset.get_samples_by_class(cls)
            indices = np.random.choice(len(cls_samples), self.k_shot + self.n_query, replace=False)

            # サポートセット
            support_imgs.extend([cls_samples[idx] for idx in indices[:self.k_shot]])
            support_labels.extend([i] * self.k_shot)

            # クエリセット
            query_imgs.extend([cls_samples[idx] for idx in indices[self.k_shot:]])
            query_labels.extend([i] * self.n_query)

        return (torch.stack(support_imgs), torch.tensor(support_labels),
                torch.stack(query_imgs), torch.tensor(query_labels))

5.3 精度比較と考察

# 各手法の比較実験
import pandas as pd
import matplotlib.pyplot as plt

def compare_few_shot_methods(test_loader, n_way=5, k_shot_list=[1, 5]):
    """複数のFew-Shot学習手法を比較"""

    results = []

    for k_shot in k_shot_list:
        print(f"\n{'='*50}")
        print(f"{n_way}-way {k_shot}-shot evaluation")
        print(f"{'='*50}\n")

        # Prototypical Network
        print("Evaluating Prototypical Network...")
        proto_model = PrototypicalNetwork(input_channels=3, hidden_dim=64)
        proto_acc, proto_ci = evaluate_few_shot(proto_model, test_loader, n_way, k_shot)
        results.append({
            'Method': 'Prototypical',
            'Setting': f'{n_way}-way {k_shot}-shot',
            'Accuracy': proto_acc,
            'CI': proto_ci
        })
        print(f"Prototypical Network: {proto_acc:.4f} ± {proto_ci:.4f}\n")

        # Matching Network
        print("Evaluating Matching Network...")
        match_model = MatchingNetwork(input_channels=3, hidden_dim=64)
        match_acc, match_ci = evaluate_few_shot(match_model, test_loader, n_way, k_shot)
        results.append({
            'Method': 'Matching',
            'Setting': f'{n_way}-way {k_shot}-shot',
            'Accuracy': match_acc,
            'CI': match_ci
        })
        print(f"Matching Network: {match_acc:.4f} ± {match_ci:.4f}\n")

        # Relation Network
        print("Evaluating Relation Network...")
        relation_model = RelationNetwork(input_channels=3, feature_dim=64)
        relation_acc, relation_ci = evaluate_few_shot(relation_model, test_loader, n_way, k_shot)
        results.append({
            'Method': 'Relation',
            'Setting': f'{n_way}-way {k_shot}-shot',
            'Accuracy': relation_acc,
            'CI': relation_ci
        })
        print(f"Relation Network: {relation_acc:.4f} ± {relation_ci:.4f}\n")

    return pd.DataFrame(results)

# 結果の可視化
def plot_comparison(results_df):
    """比較結果を可視化"""
    fig, ax = plt.subplots(figsize=(10, 6))

    # 1-shotと5-shotで分離
    settings = results_df['Setting'].unique()
    x = np.arange(len(results_df['Method'].unique()))
    width = 0.35

    for i, setting in enumerate(settings):
        data = results_df[results_df['Setting'] == setting]
        accuracies = data['Accuracy'].values
        cis = data['CI'].values

        ax.bar(x + i * width, accuracies, width,
               yerr=cis, label=setting, capsize=5)

    ax.set_xlabel('Method')
    ax.set_ylabel('Accuracy')
    ax.set_title('Few-Shot Learning Methods Comparison on miniImageNet')
    ax.set_xticks(x + width / 2)
    ax.set_xticklabels(results_df['Method'].unique())
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig('few_shot_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

# 実行例
# results_df = compare_few_shot_methods(test_loader, n_way=5, k_shot_list=[1, 5])
# plot_comparison(results_df)

典型的な結果(miniImageNet):

手法 5-way 1-shot 5-way 5-shot 主な特徴
Prototypical Networks 49.42% ± 0.78% 68.20% ± 0.66% シンプルで効率的
Matching Networks 46.60% ± 0.78% 60.00% ± 0.71% Attention機構
Relation Networks 50.44% ± 0.82% 65.32% ± 0.70% 学習可能な距離

5.4 考察と手法選択のガイドライン

Prototypical Networks

Matching Networks

Relation Networks

実践的なアドバイス: 新しいタスクでは、まずPrototypical Networksをベースラインとして試し、性能が不十分な場合にRelation Networksを検討するのが効率的です。

演習問題

演習1:Siamese Networkの改良

提供されたSiamese NetworkにTriplet Lossを実装し、Contrastive Lossとの性能を比較してください。Triplet Lossは、アンカー、ポジティブ、ネガティブの3つのサンプルを使用します。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # TODO: Triplet Lossを実装
        # ヒント: L = max(0, d(a,p) - d(a,n) + margin)
        pass
演習2:Prototypical Networksの拡張

Prototypical Networksを拡張し、各クラスのプロトタイプを単純な平均ではなく、Attention機構を使った加重平均で計算してください。これにより、ノイズの多いサンプルの影響を減らせます。

def compute_prototypes_with_attention(self, embeddings, labels, n_way):
    """Attention機構を使ったプロトタイプ計算"""
    # TODO: 実装
    # ヒント: サンプル間の類似度に基づいてattention重みを計算
    pass
演習3:マルチモーダルFew-Shot学習

画像とテキストの両方を入力とするマルチモーダルPrototypical Networkを設計してください。画像にはCNN、テキストにはTransformerを使用し、両方の埋め込みを結合します。

class MultimodalPrototypicalNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: 画像エンコーダとテキストエンコーダを定義
        pass

    def forward(self, images, texts):
        # TODO: マルチモーダル埋め込みを計算
        pass
演習4:Few-Shot学習の実アプリケーション

医療画像診断のシナリオを想定し、限られた症例画像(各疾患5枚程度)から新しい疾患を分類するシステムを設計してください。どの手法が最適か、その理由とともに説明してください。また、データ拡張やドメイン適応の戦略も考えてください。

まとめ

この章では、Few-Shot学習の主要な手法について学びました:

これらの手法は、それぞれ異なる強みを持ち、タスクやリソースに応じて選択できます。次章では、これらの手法をより高度な最適化アルゴリズム(MAML等)と組み合わせる方法を学びます。

免責事項