1. Self-Attention機構
Self-Attention(自己注意機構)は、Transformerアーキテクチャの中核をなすメカニズムです。入力シーケンス内の各トークンが、他のすべてのトークンとの関連性を動的に計算することで、長距離の依存関係を効果的に捉えることができます。
1.1 Attentionの直感的理解
Attentionの本質
「The cat sat on the mat because it was tired」という文において、「it」が何を指すかを理解するには、文全体の文脈を考慮する必要があります。Self-Attentionは、「it」と「cat」の間の強い関連性を学習し、適切な参照解決を可能にします。
1.2 数学的定式化
Self-Attentionは、Query(Q)、Key(K)、Value(V)の3つの線形変換を通じて計算されます:
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
ここで、$d_k$はKey次元であり、$\sqrt{d_k}$によるスケーリングは勾配の安定化のために重要です。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
"""
Self-Attentionの基本実装
パラメータ:
embed_dim: 埋め込み次元
num_heads: アテンションヘッド数(オプション、Multi-Head用)
"""
def __init__(self, embed_dim: int, num_heads: int = 1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert embed_dim % num_heads == 0, "embed_dimはnum_headsで割り切れる必要があります"
# Query, Key, Value の線形変換
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# 出力投影
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.scale = math.sqrt(self.head_dim)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
Args:
x: 入力テンソル [batch_size, seq_len, embed_dim]
mask: オプションの注意マスク [batch_size, seq_len, seq_len]
Returns:
出力テンソル [batch_size, seq_len, embed_dim]
"""
batch_size, seq_len, _ = x.shape
# Q, K, V を計算
Q = self.q_proj(x) # [batch, seq, embed]
K = self.k_proj(x)
V = self.v_proj(x)
# 注意スコアを計算
# [batch, seq, embed] @ [batch, embed, seq] -> [batch, seq, seq]
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# マスクを適用(因果マスクなど)
if mask is not None:
attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
# Softmaxで確率に変換
attention_weights = F.softmax(attention_scores, dim=-1)
# Valueの重み付け和
output = torch.matmul(attention_weights, V)
return self.out_proj(output)
# 使用例
embed_dim = 512
seq_len = 100
batch_size = 32
attention = SelfAttention(embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
output = attention(x)
print(f"出力形状: {output.shape}") # [32, 100, 512]
2. Multi-Head Attention
Multi-Head Attention(MHA)は、複数の「ヘッド」を並列に使用することで、異なる表現部分空間で異なるタイプの依存関係を学習します。
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention の実装
各ヘッドは異なる表現部分空間で注意を学習し、
最終的に連結されて出力される。
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = math.sqrt(self.head_dim)
# 統合されたQKV投影(効率化のため)
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor = None,
is_causal: bool = False
) -> tuple:
batch_size, seq_len, _ = x.shape
# QKVを一度に計算
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch, heads, seq, head_dim]
Q, K, V = qkv[0], qkv[1], qkv[2]
# 注意スコア計算
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# 因果マスク(デコーダ用)
if is_causal:
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=x.device),
diagonal=1
).bool()
attention_scores = attention_scores.masked_fill(causal_mask, float('-inf'))
# 追加のマスク
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Softmax と Dropout
attention_weights = F.softmax(attention_scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# 出力計算
output = torch.matmul(attention_weights, V)
output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
output = self.out_proj(output)
return output, attention_weights
3. 位置エンコーディング
Transformerは本質的に順序を認識しないため、位置情報を明示的に注入する必要があります。位置エンコーディング手法は大きく進化しており、現代のLLMでは回転位置エンコーディング(RoPE)が標準となっています。
3.1 位置エンコーディングの比較
| 手法 | タイプ | 長所 | 使用例 |
|---|---|---|---|
| 正弦波エンコーディング | 絶対位置 | シンプル、学習不要 | オリジナルTransformer |
| 学習位置エンコーディング | 絶対位置 | タスク適応的 | BERT, GPT-2 |
| RoPE | 相対位置 | 外挿性、効率的 | Llama, Qwen, Mistral |
| ALiBi | 相対位置 | 学習不要、高速 | BLOOM, MPT |
3.2 RoPE(回転位置エンコーディング)
RoPEは、位置情報をQueryとKeyベクトルに回転行列として適用します。これにより、注意計算が自然に相対位置をエンコードするようになります。
class RotaryPositionEmbedding(nn.Module):
"""
回転位置エンコーディング(RoPE)の実装
RoPEは位置情報をベクトルの回転として適用し、
相対位置情報を注意計算に自然に組み込む。
利点:
- 学習パラメータなし
- 相対位置のエンコード
- 長いシーケンスへの優れた外挿性
"""
def __init__(self, dim: int, max_seq_len: int = 8192, base: int = 10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# 周波数を事前計算
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
# 位置インデックスを事前計算
self._set_cos_sin_cache(max_seq_len)
def _set_cos_sin_cache(self, seq_len: int):
"""cos/sinキャッシュを事前計算"""
positions = torch.arange(seq_len).float()
freqs = torch.outer(positions, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""ベクトルの半分を回転"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
positions: torch.Tensor = None
) -> tuple:
"""
QueryとKeyにRoPEを適用
Args:
q: Query テンソル [batch, heads, seq, head_dim]
k: Key テンソル [batch, heads, seq, head_dim]
positions: オプションの位置インデックス
Returns:
回転後のq, kのタプル
"""
seq_len = q.shape[2]
if positions is None:
cos = self.cos_cached[:seq_len]
sin = self.sin_cached[:seq_len]
else:
cos = self.cos_cached[positions]
sin = self.sin_cached[positions]
# ブロードキャスト用に次元を調整
cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq, dim]
sin = sin.unsqueeze(0).unsqueeze(0)
# 回転を適用
q_rotated = (q * cos) + (self.rotate_half(q) * sin)
k_rotated = (k * cos) + (self.rotate_half(k) * sin)
return q_rotated, k_rotated
# 使用例
rope = RotaryPositionEmbedding(dim=64, max_seq_len=4096)
q = torch.randn(2, 8, 100, 64) # [batch, heads, seq, head_dim]
k = torch.randn(2, 8, 100, 64)
q_rot, k_rot = rope(q, k)
print(f"回転後の形状: {q_rot.shape}")
4. アーキテクチャの変種
4.1 Encoder-Decoder vs Decoder-Only
| アーキテクチャ | 代表モデル | 最適な用途 |
|---|---|---|
| Encoder-Only | BERT, RoBERTa | 分類、NER、埋め込み |
| Decoder-Only | GPT-4, Llama, Claude | テキスト生成、会話 |
| Encoder-Decoder | T5, BART | 翻訳、要約 |
5. Mixture of Experts(MoE)
MoE(Mixture of Experts)は、Transformerのフィードフォワード層を複数の「エキスパート」ネットワークに置き換え、各入力に対して少数のエキスパートのみを活性化することで、パラメータ効率とスケーラビリティを向上させます。
5.1 MoEの仕組み
MoEの主な利点
- 計算効率: 総パラメータの15-40%のみを各トークンで活性化
- スケーラビリティ: 計算コストを比例的に増やさずにモデルサイズを拡大可能
- 専門化: 各エキスパートは特定のタスクやドメインに特化できる
5.2 代表的なMoEモデル
| モデル | 総パラメータ | 活性パラメータ | エキスパート数 |
|---|---|---|---|
| Mixtral 8x22B | 176B | 39B | 8 |
| DeepSeek-V3 | 671B | 37B | 256 |
| Llama 4 Scout | 109B | 17B | 16 |
| Qwen3-235B-A22B | 235B | 22B | 128 |
class MoELayer(nn.Module):
"""
Mixture of Expertsレイヤーの実装
Transformerのフィードフォワード層を複数のエキスパートに置き換え、
ルーターが各トークンに対して最適なエキスパートを選択する。
"""
def __init__(
self,
embed_dim: int,
num_experts: int = 8,
top_k: int = 2,
expert_dim: int = None
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
expert_dim = expert_dim or embed_dim * 4
# ルーターネットワーク
self.router = nn.Linear(embed_dim, num_experts, bias=False)
# エキスパートネットワーク(各々がFFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(embed_dim, expert_dim),
nn.GELU(),
nn.Linear(expert_dim, embed_dim)
)
for _ in range(num_experts)
])
def forward(self, x: torch.Tensor) -> tuple:
"""
Args:
x: 入力テンソル [batch_size, seq_len, embed_dim]
Returns:
出力テンソルとルーター損失(負荷分散用)のタプル
"""
batch_size, seq_len, embed_dim = x.shape
x_flat = x.view(-1, embed_dim) # [batch*seq, embed]
# ルーティングスコアを計算
router_logits = self.router(x_flat) # [batch*seq, num_experts]
router_probs = F.softmax(router_logits, dim=-1)
# Top-kエキスパートを選択
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
# 選択されたエキスパートの確率を正規化
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
# 出力を初期化
output = torch.zeros_like(x_flat)
# 各エキスパートを処理
for expert_idx in range(self.num_experts):
# このエキスパートに割り当てられたトークンを見つける
expert_mask = (top_k_indices == expert_idx).any(dim=-1)
if expert_mask.any():
expert_input = x_flat[expert_mask]
expert_output = self.experts[expert_idx](expert_input)
# このエキスパートの重みを取得
expert_weights = torch.where(
top_k_indices[expert_mask] == expert_idx,
top_k_probs[expert_mask],
torch.zeros_like(top_k_probs[expert_mask])
).sum(dim=-1, keepdim=True)
output[expert_mask] += expert_weights * expert_output
# 負荷分散損失(エキスパート利用の均等化)
load_balance_loss = self._compute_load_balance_loss(router_probs)
return output.view(batch_size, seq_len, embed_dim), load_balance_loss
def _compute_load_balance_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
"""エキスパート間の負荷分散を促す補助損失"""
# 各エキスパートの平均利用率
expert_usage = router_probs.mean(dim=0)
# 均等分布からの乖離
target = 1.0 / self.num_experts
return ((expert_usage - target) ** 2).sum()
# 使用例
moe = MoELayer(embed_dim=512, num_experts=8, top_k=2)
x = torch.randn(4, 100, 512)
output, aux_loss = moe(x)
print(f"出力形状: {output.shape}, 補助損失: {aux_loss.item():.4f}")
6. 効率的なAttention機構
6.1 FlashAttention
FlashAttentionは、GPUのメモリ階層を最適化することで、標準的なAttentionと比較して2-4倍の高速化と大幅なメモリ削減を実現します。
FlashAttentionの原理
- タイリング: 入力をブロックに分割し、オンチップSRAMで処理
- 再計算: 中間結果を保存せず、必要時に再計算
- カーネル融合: 複数の操作を1つのGPUカーネルに統合
# FlashAttentionの使用(PyTorch 2.0+)
import torch
import torch.nn.functional as F
def flash_attention_example(q, k, v, is_causal=True):
"""
PyTorch 2.0+のscaled_dot_product_attentionを使用した
FlashAttention互換の実装
FlashAttention v2は以下を自動的に最適化:
- メモリ効率(O(N)のメモリ使用)
- 計算効率(タイリングによる高速化)
- 因果マスクの最適化
"""
# PyTorch 2.0+では自動的に最適な実装を選択
output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=None # 自動計算
)
return output
# 実装比較
def benchmark_attention():
"""標準Attention vs FlashAttentionのベンチマーク"""
import time
batch_size = 4
num_heads = 32
seq_len = 4096
head_dim = 128
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
# ウォームアップ
_ = flash_attention_example(q, k, v)
torch.cuda.synchronize()
# FlashAttention
start = time.time()
for _ in range(10):
_ = flash_attention_example(q, k, v)
torch.cuda.synchronize()
flash_time = (time.time() - start) / 10
print(f"FlashAttention: {flash_time*1000:.2f}ms")
print(f"シーケンス長: {seq_len}, メモリ使用量: O(N)")
# GPUが利用可能な場合実行
if torch.cuda.is_available():
benchmark_attention()
6.2 Grouped-Query Attention(GQA)
GQAは、複数のQueryヘッドで1つのKey-Valueヘッドを共有することで、KVキャッシュのメモリ使用量を削減します。
| タイプ | Q:KV比率 | KVキャッシュ削減 | 使用モデル |
|---|---|---|---|
| MHA | 1:1 | 0% | GPT-3 |
| GQA | 8:1 | 87.5% | Llama 2/3 |
| MQA | N:1 | 最大 | PaLM |
まとめ
第2章の重要ポイント
- Self-Attention: Transformerの中核、Query-Key-Valueによる動的な関連性計算
- Multi-Head Attention: 複数の表現部分空間で異なる依存関係を学習
- 位置エンコーディング: RoPEが現代LLMの標準、優れた外挿性能
- MoE: パラメータ効率を大幅に向上、同等計算で高性能を実現
- FlashAttention: メモリ効率と速度を両立、長コンテキストを実現