第3章:深層学習による音声認識

CTC、Attention、RNN-T、Whisperを用いた最新の音声認識技術

📚 音声・オーディオ処理入門 ⏱️ 60分 🏷️ ML-D03

この章で学ぶこと

本章では、深層学習による現代的な音声認識(ASR)技術を学びます。CTCやAttentionメカニズムなどの基本的なアプローチから、最新のWhisperモデルまで、実践的なコード例とともに解説します。

1. CTC (Connectionist Temporal Classification)

1.1 CTCとは

CTCは、入力と出力の長さが異なる系列変換問題において、明示的なアライメント情報なしに学習できる手法です。音声認識では、音響特徴量フレーム(入力)とテキスト(出力)の対応関係を自動的に学習します。

CTCの主要概念

1.2 CTCの損失関数

CTC損失は、全ての可能なアライメントパスの確率を周辺化することで計算されます。Forward-Backwardアルゴリズムを用いて効率的に計算できます。

import torch
import torch.nn as nn
import torch.nn.functional as F

class CTCASRModel(nn.Module):
    """CTC損失を用いた音声認識モデル"""

    def __init__(self, input_dim=80, hidden_dim=256, num_classes=29, num_layers=3):
        """
        Args:
            input_dim: 入力特徴量の次元数(MFCCやメルスペクトログラム)
            hidden_dim: LSTM隠れ層の次元数
            num_classes: 出力クラス数(アルファベット + blank)
            num_layers: LSTMの層数
        """
        super(CTCASRModel, self).__init__()

        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        # 双方向LSTMなので出力次元は2倍
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x, lengths):
        """
        Args:
            x: (batch, time, features)
            lengths: 各サンプルの実際の長さ
        Returns:
            log_probs: (time, batch, num_classes)
            output_lengths: 各サンプルの出力長
        """
        # PackedSequenceで可変長入力を効率的に処理
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True
        )

        # CTC用にlog softmax適用
        logits = self.classifier(output)
        log_probs = F.log_softmax(logits, dim=-1)

        # CTCは(T, N, C)の形式を期待
        log_probs = log_probs.transpose(0, 1)

        return log_probs, output_lengths


# 訓練例
def train_ctc_model():
    """CTC ASRモデルの訓練例"""

    # モデルとCTC損失の初期化
    model = CTCASRModel(input_dim=80, hidden_dim=256, num_classes=29)
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # ダミーデータ(実際にはデータローダーから取得)
    batch_size = 4
    max_time = 100
    features = torch.randn(batch_size, max_time, 80)
    feature_lengths = torch.tensor([100, 95, 90, 85])

    # ターゲットテキスト(数値化済み)
    targets = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 10, 0], [11, 12, 0, 0]])
    target_lengths = torch.tensor([4, 3, 3, 2])

    model.train()

    # Forward pass
    log_probs, output_lengths = model(features, feature_lengths)

    # CTC損失計算
    loss = ctc_loss(log_probs, targets, output_lengths, target_lengths)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"CTC Loss: {loss.item():.4f}")

    return model

# 実行例
if __name__ == "__main__":
    model = train_ctc_model()
    print("✓ CTCモデルの訓練が完了しました")

1.3 CTCデコーディング

訓練されたCTCモデルからテキストを得るには、デコーディング処理が必要です。主な手法には、Greedy Decodingとビームサーチがあります。

import numpy as np
from collections import defaultdict

class CTCDecoder:
    """CTCデコーダー(Greedy & Beam Search)"""

    def __init__(self, labels, blank_idx=0):
        """
        Args:
            labels: 文字ラベルのリスト
            blank_idx: ブランク記号のインデックス
        """
        self.labels = labels
        self.blank_idx = blank_idx

    def greedy_decode(self, log_probs):
        """
        Greedy Decoding: 各時刻で最も確率の高いラベルを選択

        Args:
            log_probs: (time, num_classes) の対数確率
        Returns:
            decoded_text: デコードされたテキスト
        """
        # 各時刻で最も確率の高いインデックスを取得
        best_path = torch.argmax(log_probs, dim=-1)

        # 連続する重複とblankを除去
        decoded = []
        prev_idx = self.blank_idx

        for idx in best_path:
            idx = idx.item()
            if idx != self.blank_idx and idx != prev_idx:
                decoded.append(self.labels[idx])
            prev_idx = idx

        return ''.join(decoded)

    def beam_search_decode(self, log_probs, beam_width=10):
        """
        Beam Search Decoding: より正確なデコーディング

        Args:
            log_probs: (time, num_classes) の対数確率
            beam_width: ビーム幅
        Returns:
            decoded_text: デコードされたテキスト
        """
        T, C = log_probs.shape
        log_probs = log_probs.cpu().numpy()

        # ビーム: {sequence: probability}
        beams = {('', self.blank_idx): 0.0}  # (text, last_char): log_prob

        for t in range(T):
            new_beams = defaultdict(lambda: float('-inf'))

            for (text, last_char), log_prob in beams.items():
                for c in range(C):
                    new_log_prob = log_prob + log_probs[t, c]

                    if c == self.blank_idx:
                        # Blankの場合はテキストを変更しない
                        new_beams[(text, c)] = np.logaddexp(
                            new_beams[(text, c)], new_log_prob
                        )
                    else:
                        if c == last_char:
                            # 前の文字と同じ場合は繰り返さない
                            new_beams[(text, c)] = np.logaddexp(
                                new_beams[(text, c)], new_log_prob
                            )
                        else:
                            # 新しい文字を追加
                            new_text = text + self.labels[c]
                            new_beams[(new_text, c)] = np.logaddexp(
                                new_beams[(new_text, c)], new_log_prob
                            )

            # 上位beam_width個のビームを保持
            beams = dict(sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width])

        # 最も確率の高いビームを返す
        best_beam = max(beams.items(), key=lambda x: x[1])
        return best_beam[0][0]


# デコーディング例
def decode_example():
    """CTC デコーディングの実行例"""

    # アルファベット(0はblank)
    labels = ['-', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
              'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', "'"]

    decoder = CTCDecoder(labels, blank_idx=0)

    # ダミーの対数確率(実際はモデル出力)
    T = 50
    C = len(labels)
    log_probs = torch.randn(T, C)
    log_probs = F.log_softmax(log_probs, dim=-1)

    # Greedy Decoding
    greedy_text = decoder.greedy_decode(log_probs)
    print(f"Greedy Decode: {greedy_text}")

    # Beam Search Decoding
    beam_text = decoder.beam_search_decode(log_probs, beam_width=10)
    print(f"Beam Search Decode: {beam_text}")

decode_example()

2. Attention-based Models

2.1 Listen, Attend and Spell (LAS)

LASは、Encoder-DecoderアーキテクチャにAttentionメカニズムを組み合わせた音声認識モデルです。Encoderが音響特徴量を高レベル表現に変換し、DecoderがAttentionで必要な情報に注目しながらテキストを生成します。

import torch
import torch.nn as nn
import torch.nn.functional as F

class ListenAttendSpell(nn.Module):
    """Listen, Attend and Spell モデル"""

    def __init__(self, input_dim=80, encoder_hidden=256, decoder_hidden=512,
                 vocab_size=29, num_layers=3):
        super(ListenAttendSpell, self).__init__()

        self.encoder = Listener(input_dim, encoder_hidden, num_layers)
        self.decoder = Speller(encoder_hidden * 2, decoder_hidden, vocab_size)

    def forward(self, inputs, input_lengths, targets=None, teacher_forcing_ratio=0.9):
        """
        Args:
            inputs: (batch, time, features)
            input_lengths: 各サンプルの長さ
            targets: (batch, target_len) デコーダーのターゲット
            teacher_forcing_ratio: Teacher Forcingの割合
        """
        # Encoder
        encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

        # Decoder
        if targets is not None:
            outputs = self.decoder(encoder_outputs, encoder_lengths, targets,
                                  teacher_forcing_ratio)
        else:
            outputs = self.decoder.inference(encoder_outputs, encoder_lengths)

        return outputs


class Listener(nn.Module):
    """Encoder: 音響特徴量を高レベル表現に変換"""

    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Listener, self).__init__()

        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        # Pyramidal LSTM: 時間方向を圧縮して計算効率を向上
        self.pyramid_lstm = nn.LSTM(
            hidden_dim * 4,  # 2つのフレームを結合 + 双方向
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

    def forward(self, x, lengths):
        """
        Args:
            x: (batch, time, features)
            lengths: 各サンプルの長さ
        """
        # 第1層LSTM
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, _ = self.lstm(packed)
        output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # Pyramidal: 2フレームを1つに結合して時間を半分に
        batch, time, features = output.size()
        if time % 2 == 1:
            output = output[:, :-1, :]
            time -= 1

        output = output.reshape(batch, time // 2, features * 2)
        lengths = lengths // 2

        # 第2層LSTM
        packed = nn.utils.rnn.pack_padded_sequence(
            output, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, _ = self.pyramid_lstm(packed)
        output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        return output, lengths


class Speller(nn.Module):
    """Decoder: Attentionを用いてテキストを生成"""

    def __init__(self, encoder_dim, hidden_dim, vocab_size):
        super(Speller, self).__init__()

        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTMCell(hidden_dim + encoder_dim, hidden_dim)
        self.attention = BahdanauAttention(encoder_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim + encoder_dim, vocab_size)

    def forward(self, encoder_outputs, encoder_lengths, targets, teacher_forcing_ratio=0.9):
        """
        Teacher Forcing付き訓練
        """
        batch_size = encoder_outputs.size(0)
        max_len = targets.size(1)

        # 初期状態
        hidden = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)
        cell = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)

        # 開始トークン
        input_token = torch.zeros(batch_size, dtype=torch.long, device=encoder_outputs.device)

        outputs = []

        for t in range(max_len):
            # Embedding
            embedded = self.embedding(input_token)

            # Attention
            context, _ = self.attention(hidden, encoder_outputs, encoder_lengths)

            # LSTM
            lstm_input = torch.cat([embedded, context], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))

            # 出力
            output = self.classifier(torch.cat([hidden, context], dim=1))
            outputs.append(output)

            # Teacher Forcing
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher_forcing:
                input_token = targets[:, t]
            else:
                input_token = output.argmax(dim=1)

        return torch.stack(outputs, dim=1)

    def inference(self, encoder_outputs, encoder_lengths, max_len=100):
        """推論時のデコーディング"""
        batch_size = encoder_outputs.size(0)

        hidden = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)
        cell = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)

        input_token = torch.zeros(batch_size, dtype=torch.long, device=encoder_outputs.device)

        outputs = []

        for t in range(max_len):
            embedded = self.embedding(input_token)
            context, _ = self.attention(hidden, encoder_outputs, encoder_lengths)

            lstm_input = torch.cat([embedded, context], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))

            output = self.classifier(torch.cat([hidden, context], dim=1))
            outputs.append(output)

            input_token = output.argmax(dim=1)

            # 終了トークンで停止
            if (input_token == self.vocab_size - 1).all():
                break

        return torch.stack(outputs, dim=1)


class BahdanauAttention(nn.Module):
    """Bahdanau Attentionメカニズム"""

    def __init__(self, encoder_dim, decoder_dim):
        super(BahdanauAttention, self).__init__()

        self.encoder_projection = nn.Linear(encoder_dim, decoder_dim)
        self.decoder_projection = nn.Linear(decoder_dim, decoder_dim)
        self.v = nn.Linear(decoder_dim, 1)

    def forward(self, decoder_hidden, encoder_outputs, encoder_lengths):
        """
        Args:
            decoder_hidden: (batch, decoder_dim)
            encoder_outputs: (batch, time, encoder_dim)
            encoder_lengths: (batch,)
        Returns:
            context: (batch, encoder_dim)
            attention_weights: (batch, time)
        """
        batch_size, time_steps, _ = encoder_outputs.size()

        # プロジェクション
        encoder_proj = self.encoder_projection(encoder_outputs)  # (batch, time, decoder_dim)
        decoder_proj = self.decoder_projection(decoder_hidden).unsqueeze(1)  # (batch, 1, decoder_dim)

        # Energy計算
        energy = self.v(torch.tanh(encoder_proj + decoder_proj)).squeeze(-1)  # (batch, time)

        # マスク(パディング部分を無視)
        mask = torch.arange(time_steps, device=encoder_outputs.device).expand(
            batch_size, time_steps
        ) < encoder_lengths.unsqueeze(1)

        energy = energy.masked_fill(~mask, float('-inf'))

        # Attention weights
        attention_weights = F.softmax(energy, dim=1)  # (batch, time)

        # Context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attention_weights


# 使用例
def las_example():
    """LASモデルの使用例"""

    model = ListenAttendSpell(
        input_dim=80,
        encoder_hidden=256,
        decoder_hidden=512,
        vocab_size=29
    )

    # ダミーデータ
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(0, 29, (batch_size, 20))

    # 訓練モード
    outputs = model(inputs, input_lengths, targets, teacher_forcing_ratio=0.9)
    print(f"Training output shape: {outputs.shape}")

    # 推論モード
    model.eval()
    with torch.no_grad():
        predictions = model(inputs, input_lengths, targets=None)
        print(f"Inference output shape: {predictions.shape}")

las_example()

2.2 Transformer for ASR

TransformerアーキテクチャをASRに適用することで、並列処理が可能になり、長距離依存関係をより効果的に捉えることができます。

import torch
import torch.nn as nn
import math

class TransformerASR(nn.Module):
    """Transformer based ASR model"""

    def __init__(self, input_dim=80, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, vocab_size=29, dropout=0.1):
        super(TransformerASR, self).__init__()

        # 入力プロジェクション
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )

        # 出力層
        self.output_projection = nn.Linear(d_model, vocab_size)

        # パラメータ
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Embedding (デコーダー入力用)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """
        Args:
            src: (batch, src_len, input_dim)
            tgt: (batch, tgt_len)
            src_mask: Encoder self-attention mask
            tgt_mask: Decoder self-attention mask (causal)
            src_key_padding_mask: (batch, src_len) パディングマスク
            tgt_key_padding_mask: (batch, tgt_len) パディングマスク
        """
        # Encoder入力の準備
        src = self.input_projection(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        # Decoder入力の準備
        tgt_embedded = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_embedded = self.pos_encoder(tgt_embedded)

        # Transformer
        output = self.transformer(
            src, tgt_embedded,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        # 出力プロジェクション
        output = self.output_projection(output)

        return output

    def generate_square_subsequent_mask(self, sz):
        """Causal mask生成(未来の情報を見ないようにする)"""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
        return mask


class PositionalEncoding(nn.Module):
    """Positional Encoding for Transformer"""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positional encodingの事前計算
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# 訓練例
def train_transformer_asr():
    """Transformer ASRの訓練例"""

    model = TransformerASR(
        input_dim=80,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        vocab_size=29
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0はパディング
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98))

    # ダミーデータ
    batch_size = 8
    src = torch.randn(batch_size, 100, 80)
    tgt = torch.randint(1, 29, (batch_size, 30))

    # パディングマスク
    src_key_padding_mask = torch.zeros(batch_size, 100).bool()
    tgt_key_padding_mask = torch.zeros(batch_size, 30).bool()

    # Causal mask
    tgt_mask = model.generate_square_subsequent_mask(30).to(tgt.device)

    # Forward
    output = model(src, tgt[:, :-1], tgt_mask=tgt_mask,
                   src_key_padding_mask=src_key_padding_mask,
                   tgt_key_padding_mask=tgt_key_padding_mask[:, :-1])

    # Loss計算
    loss = criterion(output.reshape(-1, model.vocab_size), tgt[:, 1:].reshape(-1))

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Transformer ASR Loss: {loss.item():.4f}")

    return model

train_transformer_asr()

2.3 Joint CTC-Attention

CTCとAttentionを組み合わせることで、それぞれの長所を活かすことができます。CTCはアライメントの学習を助け、Attentionは文脈情報を活用します。

class JointCTCAttention(nn.Module):
    """CTC と Attention を組み合わせたハイブリッドモデル"""

    def __init__(self, input_dim=80, encoder_hidden=256, decoder_hidden=512,
                 vocab_size=29, num_layers=3, ctc_weight=0.3):
        super(JointCTCAttention, self).__init__()

        # 共有Encoder
        self.encoder = Listener(input_dim, encoder_hidden, num_layers)

        # CTC用の分類器
        self.ctc_classifier = nn.Linear(encoder_hidden * 2, vocab_size)

        # Attention-based Decoder
        self.decoder = Speller(encoder_hidden * 2, decoder_hidden, vocab_size)

        # CTC損失の重み
        self.ctc_weight = ctc_weight
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    def forward(self, inputs, input_lengths, targets, target_lengths=None,
                teacher_forcing_ratio=0.9):
        """
        Args:
            inputs: (batch, time, features)
            input_lengths: 各サンプルの音響特徴量の長さ
            targets: (batch, target_len) テキストターゲット
            target_lengths: 各ターゲットの長さ(CTC用)
            teacher_forcing_ratio: Teacher Forcingの割合
        Returns:
            ctc_loss: CTC損失
            attention_loss: Attention損失
            combined_loss: 組み合わせた損失
        """
        # Encoder (共有)
        encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

        # CTC分岐
        ctc_logits = self.ctc_classifier(encoder_outputs)
        ctc_log_probs = F.log_softmax(ctc_logits, dim=-1)
        ctc_log_probs = ctc_log_probs.transpose(0, 1)  # (T, N, C)

        # Attention分岐
        attention_outputs = self.decoder(encoder_outputs, encoder_lengths,
                                        targets, teacher_forcing_ratio)

        return ctc_log_probs, encoder_lengths, attention_outputs

    def compute_loss(self, ctc_log_probs, encoder_lengths, attention_outputs,
                     targets, target_lengths):
        """
        損失の計算
        """
        # CTC損失
        ctc_loss = self.ctc_loss(ctc_log_probs, targets, encoder_lengths, target_lengths)

        # Attention損失 (Cross Entropy)
        attention_loss = F.cross_entropy(
            attention_outputs.reshape(-1, attention_outputs.size(-1)),
            targets.reshape(-1),
            ignore_index=0
        )

        # 組み合わせた損失
        combined_loss = self.ctc_weight * ctc_loss + (1 - self.ctc_weight) * attention_loss

        return ctc_loss, attention_loss, combined_loss

    def recognize(self, inputs, input_lengths, beam_width=10):
        """
        推論: CTCとAttentionの両方を使用したデコーディング
        """
        self.eval()
        with torch.no_grad():
            # Encoder
            encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

            # CTC分岐 (プレフィックスビームサーチ用)
            ctc_logits = self.ctc_classifier(encoder_outputs)
            ctc_probs = F.softmax(ctc_logits, dim=-1)

            # Attention分岐
            attention_outputs = self.decoder.inference(encoder_outputs, encoder_lengths)
            attention_probs = F.softmax(attention_outputs, dim=-1)

            # CTC と Attentionのスコアを組み合わせてデコード
            # (実際にはより複雑なビームサーチアルゴリズムを使用)
            combined_probs = (self.ctc_weight * ctc_probs[0] +
                            (1 - self.ctc_weight) * attention_probs)

            predictions = combined_probs.argmax(dim=-1)

            return predictions


# 訓練例
def train_joint_model():
    """Joint CTC-Attention モデルの訓練"""

    model = JointCTCAttention(
        input_dim=80,
        encoder_hidden=256,
        decoder_hidden=512,
        vocab_size=29,
        ctc_weight=0.3
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # ダミーデータ
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(1, 29, (batch_size, 20))
    target_lengths = torch.tensor([20, 18, 17, 15])

    # Forward
    ctc_log_probs, encoder_lengths, attention_outputs = model(
        inputs, input_lengths, targets, target_lengths
    )

    # 損失計算
    ctc_loss, attention_loss, combined_loss = model.compute_loss(
        ctc_log_probs, encoder_lengths, attention_outputs, targets, target_lengths
    )

    # Backward
    optimizer.zero_grad()
    combined_loss.backward()
    optimizer.step()

    print(f"CTC Loss: {ctc_loss.item():.4f}")
    print(f"Attention Loss: {attention_loss.item():.4f}")
    print(f"Combined Loss: {combined_loss.item():.4f}")

    return model

train_joint_model()

3. RNN-Transducer (RNN-T)

3.1 RNN-Tとは

RNN-Transducerは、ストリーミング音声認識に適したモデルです。音響モデル(Encoder)、言語モデル(Prediction Network)、そしてJoint Networkの3つのコンポーネントから構成されます。CTCと異なり、言語モデルを明示的に組み込むことができます。

RNN-Tの特徴

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNTransducer(nn.Module):
    """RNN-Transducer モデル"""

    def __init__(self, input_dim=80, encoder_dim=256, pred_dim=256,
                 joint_dim=512, vocab_size=29, num_layers=3):
        """
        Args:
            input_dim: 入力特徴量の次元
            encoder_dim: Encoderの隠れ層次元
            pred_dim: Prediction Networkの隠れ層次元
            joint_dim: Joint Networkの隠れ層次元
            vocab_size: 語彙サイズ(blank含む)
        """
        super(RNNTransducer, self).__init__()

        # Encoder (Transcription Network)
        self.encoder = nn.LSTM(
            input_dim,
            encoder_dim,
            num_layers,
            batch_first=True,
            bidirectional=False,  # ストリーミング用に単方向
            dropout=0.2
        )

        # Prediction Network (Language Model)
        self.embedding = nn.Embedding(vocab_size, pred_dim)
        self.prediction = nn.LSTM(
            pred_dim,
            pred_dim,
            num_layers,
            batch_first=True,
            dropout=0.2
        )

        # Joint Network
        self.joint = JointNetwork(encoder_dim, pred_dim, joint_dim, vocab_size)

        self.vocab_size = vocab_size
        self.blank_idx = 0

    def forward(self, inputs, input_lengths, targets, target_lengths):
        """
        Args:
            inputs: (batch, time, features) 音響特徴量
            input_lengths: 各サンプルの入力長
            targets: (batch, target_len) ターゲットラベル
            target_lengths: 各サンプルのターゲット長
        Returns:
            joint_output: (batch, time, target_len+1, vocab_size)
        """
        # Encoder
        encoder_out, _ = self.encoder(inputs)  # (batch, time, encoder_dim)

        # Prediction Network
        # 開始トークンを追加
        batch_size = targets.size(0)
        start_tokens = torch.zeros(batch_size, 1, dtype=torch.long, device=targets.device)
        pred_input = torch.cat([start_tokens, targets], dim=1)

        pred_embedded = self.embedding(pred_input)
        pred_out, _ = self.prediction(pred_embedded)  # (batch, target_len+1, pred_dim)

        # Joint Network
        joint_output = self.joint(encoder_out, pred_out)

        return joint_output

    def greedy_decode(self, inputs, input_lengths, max_len=100):
        """
        Greedy decoding for inference
        """
        self.eval()
        with torch.no_grad():
            batch_size = inputs.size(0)

            # Encoder
            encoder_out, _ = self.encoder(inputs)
            time_steps = encoder_out.size(1)

            # 初期化
            predictions = []
            pred_hidden = None

            # 開始トークン
            pred_input = torch.zeros(batch_size, 1, dtype=torch.long, device=inputs.device)

            for t in range(time_steps):
                # 現在のエンコーダー出力
                enc_t = encoder_out[:, t:t+1, :]  # (batch, 1, encoder_dim)

                # Prediction Network
                pred_embedded = self.embedding(pred_input)
                pred_out, pred_hidden = self.prediction(pred_embedded, pred_hidden)

                # Joint Network
                joint_out = self.joint(enc_t, pred_out)  # (batch, 1, 1, vocab_size)

                # 最も確率の高いトークンを選択
                prob = F.softmax(joint_out.squeeze(1).squeeze(1), dim=-1)
                pred_token = prob.argmax(dim=-1)

                # Blankでない場合のみ出力に追加
                if pred_token.item() != self.blank_idx:
                    predictions.append(pred_token.item())
                    pred_input = pred_token.unsqueeze(1)

                if len(predictions) >= max_len:
                    break

            return predictions


class JointNetwork(nn.Module):
    """Joint Network: EncoderとPrediction Networkの出力を結合"""

    def __init__(self, encoder_dim, pred_dim, joint_dim, vocab_size):
        super(JointNetwork, self).__init__()

        self.encoder_proj = nn.Linear(encoder_dim, joint_dim)
        self.pred_proj = nn.Linear(pred_dim, joint_dim)
        self.output_proj = nn.Linear(joint_dim, vocab_size)

    def forward(self, encoder_out, pred_out):
        """
        Args:
            encoder_out: (batch, time, encoder_dim)
            pred_out: (batch, target_len, pred_dim)
        Returns:
            joint_out: (batch, time, target_len, vocab_size)
        """
        # プロジェクション
        enc_proj = self.encoder_proj(encoder_out)  # (batch, time, joint_dim)
        pred_proj = self.pred_proj(pred_out)  # (batch, target_len, joint_dim)

        # ブロードキャストして加算
        # (batch, time, 1, joint_dim) + (batch, 1, target_len, joint_dim)
        joint = torch.tanh(
            enc_proj.unsqueeze(2) + pred_proj.unsqueeze(1)
        )  # (batch, time, target_len, joint_dim)

        # 出力プロジェクション
        output = self.output_proj(joint)  # (batch, time, target_len, vocab_size)

        return output


# RNN-T Loss (簡易版)
class RNNTLoss(nn.Module):
    """RNN-T損失関数 (Forward-Backward アルゴリズム)"""

    def __init__(self, blank_idx=0):
        super(RNNTLoss, self).__init__()
        self.blank_idx = blank_idx

    def forward(self, logits, targets, input_lengths, target_lengths):
        """
        Args:
            logits: (batch, time, target_len+1, vocab_size)
            targets: (batch, target_len)
            input_lengths: (batch,)
            target_lengths: (batch,)
        """
        # PyTorchのtorchaudio.functional.rnnt_lossを使用
        # ここでは簡易的にCTC損失で代用
        batch_size, time, _, vocab_size = logits.size()

        # Greedy path approximation
        log_probs = F.log_softmax(logits, dim=-1)

        # 各時刻でのblankと非blankの確率を計算
        loss = 0
        for b in range(batch_size):
            for t in range(input_lengths[b]):
                for u in range(target_lengths[b] + 1):
                    if u < target_lengths[b]:
                        # 非blank: 正しいラベルの確率
                        target_label = targets[b, u]
                        loss -= log_probs[b, t, u, target_label]
                    else:
                        # blank
                        loss -= log_probs[b, t, u, self.blank_idx]

        return loss / batch_size


# 使用例
def train_rnnt():
    """RNN-T モデルの訓練例"""

    model = RNNTransducer(
        input_dim=80,
        encoder_dim=256,
        pred_dim=256,
        joint_dim=512,
        vocab_size=29
    )

    criterion = RNNTLoss(blank_idx=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # ダミーデータ
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(1, 29, (batch_size, 20))
    target_lengths = torch.tensor([20, 18, 17, 15])

    # Forward
    logits = model(inputs, input_lengths, targets, target_lengths)

    # Loss
    loss = criterion(logits, targets, input_lengths, target_lengths)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"RNN-T Loss: {loss.item():.4f}")

    # 推論例
    predictions = model.greedy_decode(inputs[:1], input_lengths[:1])
    print(f"Predicted tokens: {predictions}")

    return model

train_rnnt()

3.2 CTCとRNN-Tの比較

特徴 CTC RNN-Transducer
言語モデル 外部LMが必要 Prediction Networkで内部統合
ストリーミング 可能(単方向LSTM使用時) 設計上ストリーミング対応
計算コスト 低い やや高い(Joint Network)
精度 中程度 高い(言語モデル統合効果)
訓練の安定性 比較的安定 やや不安定な場合あり

4. Whisper

4.1 OpenAI Whisperとは

Whisperは、OpenAIが開発した多言語音声認識モデルです。68万時間の多言語データで訓練されており、99言語の音声認識、音声翻訳、言語識別タスクに対応しています。

Whisperの特徴

import whisper
import torch
import numpy as np

# Whisperモデルの基本的な使用
def basic_whisper_usage():
    """Whisperの基本的な使い方"""

    # モデルのロード (base, small, medium, large から選択)
    model = whisper.load_model("base")

    print(f"モデルサイズ: {model.dims}")
    print(f"対応言語数: {len(whisper.tokenizer.LANGUAGES)}")

    # 音声ファイルの読み込みと文字起こし
    # audio = whisper.load_audio("audio.mp3")
    # audio = whisper.pad_or_trim(audio)

    # ダミー音声(実際にはファイルから読み込み)
    audio = np.random.randn(16000 * 10).astype(np.float32)  # 10秒の音声

    # メルスペクトログラムに変換
    mel = whisper.log_mel_spectrogram(torch.from_numpy(audio)).to(model.device)

    # 言語検出
    _, probs = model.detect_language(mel)
    detected_language = max(probs, key=probs.get)
    print(f"検出された言語: {detected_language} (確率: {probs[detected_language]:.2f})")

    # オプション設定
    options = whisper.DecodingOptions(
        language="ja",  # 日本語を指定
        task="transcribe",  # transcribe or translate
        fp16=False  # FP16を使用するか
    )

    # デコーディング
    result = whisper.decode(model, mel, options)

    print(f"文字起こし結果: {result.text}")
    print(f"平均対数確率: {result.avg_logprob:.4f}")
    print(f"圧縮率: {result.compression_ratio:.2f}")

    return model, result


# より高レベルなAPI
def transcribe_audio_file(audio_path, model_size="base"):
    """
    音声ファイルを文字起こし

    Args:
        audio_path: 音声ファイルのパス
        model_size: モデルサイズ (tiny, base, small, medium, large)
    """
    # モデルロード
    model = whisper.load_model(model_size)

    # 文字起こし
    result = model.transcribe(
        audio_path,
        language="ja",  # 日本語
        task="transcribe",
        verbose=True,  # 進捗表示
        temperature=0.0,  # 温度(多様性の制御)
        best_of=5,  # ビームサーチの候補数
        beam_size=5,  # ビーム幅
        patience=1.0,  # ビームサーチの忍耐度
        length_penalty=1.0,  # 長さペナルティ
        compression_ratio_threshold=2.4,  # 圧縮率の閾値
        logprob_threshold=-1.0,  # 対数確率の閾値
        no_speech_threshold=0.6  # 無音判定の閾値
    )

    # 結果の表示
    print("=" * 50)
    print("文字起こし結果:")
    print("=" * 50)
    print(result["text"])
    print()

    # セグメント単位の結果
    print("セグメント詳細:")
    for segment in result["segments"]:
        start = segment["start"]
        end = segment["end"]
        text = segment["text"]
        print(f"[{start:.2f}s - {end:.2f}s] {text}")

    return result


# タイムスタンプ付き文字起こし
def transcribe_with_timestamps(audio_path):
    """タイムスタンプ付きで文字起こし"""

    model = whisper.load_model("base")

    # word_timestamps=Trueで単語レベルのタイムスタンプを取得
    result = model.transcribe(
        audio_path,
        language="ja",
        word_timestamps=True  # 単語レベルのタイムスタンプ
    )

    # 単語レベルの結果を表示
    for segment in result["segments"]:
        if "words" in segment:
            for word_info in segment["words"]:
                word = word_info["word"]
                start = word_info["start"]
                end = word_info["end"]
                probability = word_info.get("probability", 0.0)
                print(f"{word:15s} [{start:6.2f}s - {end:6.2f}s] (確率: {probability:.3f})")

    return result


# 多言語音声の処理
def multilingual_transcription(audio_path):
    """多言語音声の処理"""

    model = whisper.load_model("medium")  # 多言語には中型以上を推奨

    # 言語を自動検出
    result = model.transcribe(
        audio_path,
        task="transcribe",
        language=None  # 自動検出
    )

    detected_language = result["language"]
    print(f"検出された言語: {detected_language}")
    print(f"文字起こし: {result['text']}")

    # 英語に翻訳
    translation = model.transcribe(
        audio_path,
        task="translate",  # 英語に翻訳
        language=detected_language
    )

    print(f"英訳: {translation['text']}")

    return result, translation


# 実行例
if __name__ == "__main__":
    print("Whisper使用例")
    print("=" * 50)

    # 基本的な使用法
    model, result = basic_whisper_usage()
    print("✓ 基本的な文字起こしが完了しました")

    # 注: 実際の音声ファイルを使う場合は以下のコメントを外す
    # result = transcribe_audio_file("sample.mp3", model_size="base")
    # result = transcribe_with_timestamps("sample.mp3")
    # result, translation = multilingual_transcription("sample.mp3")

4.2 Whisperのアーキテクチャ

WhisperはEncoder-Decoderアーキテクチャを採用しています。Encoderは音響特徴量を処理し、Decoderはテキストを生成します。どちらもTransformerベースです。

graph LR A[音声入力] --> B[メルスペクトログラム] B --> C[Encoder
Transformer] C --> D[音響表現] D --> E[Decoder
Transformer] E --> F[テキスト出力] style A fill:#e1f5ff style F fill:#e1f5ff style C fill:#fff4e1 style E fill:#fff4e1
import torch
import torch.nn as nn
from typing import Optional

class WhisperArchitecture(nn.Module):
    """
    Whisperのアーキテクチャ概要
    (実際のWhisperはより複雑ですが、主要な構造を示します)
    """

    def __init__(self,
                 n_mels=80,
                 n_audio_ctx=1500,
                 n_audio_state=512,
                 n_audio_head=8,
                 n_audio_layer=6,
                 n_vocab=51865,
                 n_text_ctx=448,
                 n_text_state=512,
                 n_text_head=8,
                 n_text_layer=6):
        """
        Args:
            n_mels: メルフィルタバンク数
            n_audio_ctx: 音響コンテキスト長
            n_audio_state: Encoderの状態次元
            n_audio_head: Encoderのアテンションヘッド数
            n_audio_layer: Encoderの層数
            n_vocab: 語彙サイズ
            n_text_ctx: テキストコンテキスト長
            n_text_state: Decoderの状態次元
            n_text_head: Decoderのアテンションヘッド数
            n_text_layer: Decoderの層数
        """
        super().__init__()

        # Encoder: 音響特徴量を処理
        self.encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=n_audio_ctx,
            n_state=n_audio_state,
            n_head=n_audio_head,
            n_layer=n_audio_layer
        )

        # Decoder: テキストを生成
        self.decoder = TextDecoder(
            n_vocab=n_vocab,
            n_ctx=n_text_ctx,
            n_state=n_text_state,
            n_head=n_text_head,
            n_layer=n_text_layer
        )

    def forward(self, mel, tokens):
        """
        Args:
            mel: (batch, n_mels, time) メルスペクトログラム
            tokens: (batch, seq_len) トークン
        """
        # Encoder
        audio_features = self.encoder(mel)

        # Decoder
        logits = self.decoder(tokens, audio_features)

        return logits


class AudioEncoder(nn.Module):
    """Whisper Audio Encoder"""

    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        # 畳み込み層で特徴抽出
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)

        # Positional embedding
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        # Transformer layers
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head) for _ in range(n_layer)
        ])

        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x):
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            (batch, time//2, n_state)
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, time, n_state)

        # Positional embedding
        x = x + self.positional_embedding[:x.size(1)]

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    """Whisper Text Decoder"""

    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        # Token embedding
        self.token_embedding = nn.Embedding(n_vocab, n_state)

        # Positional embedding
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        # Transformer layers (with cross-attention)
        self.blocks = nn.ModuleList([
            DecoderBlock(n_state, n_head) for _ in range(n_layer)
        ])

        self.ln = nn.LayerNorm(n_state)

    def forward(self, tokens, audio_features):
        """
        Args:
            tokens: (batch, seq_len)
            audio_features: (batch, audio_len, n_state)
        Returns:
            (batch, seq_len, n_vocab)
        """
        x = self.token_embedding(tokens)
        x = x + self.positional_embedding[:x.size(1)]

        for block in self.blocks:
            x = block(x, audio_features)

        x = self.ln(x)

        # Weight tying: token embeddingを再利用
        logits = x @ self.token_embedding.weight.T

        return logits


class TransformerBlock(nn.Module):
    """Transformer Encoder Block"""

    def __init__(self, n_state, n_head):
        super().__init__()

        self.attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.attn_ln = nn.LayerNorm(n_state)

        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4),
            nn.GELU(),
            nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attn(x, x, x)
        x = self.attn_ln(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.mlp_ln(x + mlp_out)

        return x


class DecoderBlock(nn.Module):
    """Transformer Decoder Block (with cross-attention)"""

    def __init__(self, n_state, n_head):
        super().__init__()

        # Self-attention
        self.self_attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.self_attn_ln = nn.LayerNorm(n_state)

        # Cross-attention
        self.cross_attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.cross_attn_ln = nn.LayerNorm(n_state)

        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4),
            nn.GELU(),
            nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x, audio_features):
        # Self-attention (causal mask)
        attn_out, _ = self.self_attn(x, x, x, need_weights=False, is_causal=True)
        x = self.self_attn_ln(x + attn_out)

        # Cross-attention
        cross_out, _ = self.cross_attn(x, audio_features, audio_features)
        x = self.cross_attn_ln(x + cross_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.mlp_ln(x + mlp_out)

        return x


# アーキテクチャの概要表示
def show_architecture_info():
    """Whisperの各モデルサイズの情報"""

    models_info = {
        "tiny": {
            "parameters": "39M",
            "n_audio_layer": 4,
            "n_text_layer": 4,
            "n_state": 384,
            "n_head": 6
        },
        "base": {
            "parameters": "74M",
            "n_audio_layer": 6,
            "n_text_layer": 6,
            "n_state": 512,
            "n_head": 8
        },
        "small": {
            "parameters": "244M",
            "n_audio_layer": 12,
            "n_text_layer": 12,
            "n_state": 768,
            "n_head": 12
        },
        "medium": {
            "parameters": "769M",
            "n_audio_layer": 24,
            "n_text_layer": 24,
            "n_state": 1024,
            "n_head": 16
        },
        "large": {
            "parameters": "1550M",
            "n_audio_layer": 32,
            "n_text_layer": 32,
            "n_state": 1280,
            "n_head": 20
        }
    }

    print("Whisper モデルサイズ比較")
    print("=" * 70)
    print(f"{'Model':<10} {'Parameters':<12} {'Layers':<10} {'State':<10} {'Heads':<10}")
    print("=" * 70)

    for model, info in models_info.items():
        print(f"{model:<10} {info['parameters']:<12} "
              f"{info['n_audio_layer']}/{info['n_text_layer']:<10} "
              f"{info['n_state']:<10} {info['n_head']:<10}")

    print()
    print("推奨用途:")
    print("- tiny/base: リアルタイム処理、リソース制約環境")
    print("- small: バランス型、一般的な用途")
    print("- medium/large: 高精度が必要な場合、多言語処理")

show_architecture_info()

5. 実践的なASRシステム

5.1 Whisperのファインチューニング

Whisperを特定のドメインや言語にファインチューニングすることで、さらに高い精度を実現できます。

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, Audio
import torch

def finetune_whisper_japanese():
    """
    日本語音声認識のためのWhisperファインチューニング
    """

    # モデルとプロセッサのロード
    model_name = "openai/whisper-small"
    processor = WhisperProcessor.from_pretrained(model_name, language="Japanese", task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    # 日本語に特化した設定
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="Japanese",
        task="transcribe"
    )
    model.config.suppress_tokens = []

    # データセットの準備(例: Common Voice Japanese)
    # dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ja", split="train[:100]")

    # ダミーデータセット(実際には上記のようなデータセットを使用)
    def prepare_dataset(batch):
        """データセットの前処理"""
        # 音声を16kHzにリサンプリング
        audio = batch["audio"]

        # メルスペクトログラムに変換
        batch["input_features"] = processor(
            audio["array"],
            sampling_rate=audio["sampling_rate"],
            return_tensors="pt"
        ).input_features[0]

        # ラベルの準備
        batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids

        return batch

    # データセットに前処理を適用
    # dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)

    # データコレーター
    from dataclasses import dataclass
    from typing import Any, Dict, List, Union

    @dataclass
    class DataCollatorSpeechSeq2SeqWithPadding:
        """音声データのコレーター"""

        processor: Any

        def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            # 入力特徴量をパディング
            input_features = [{"input_features": feature["input_features"]} for feature in features]
            batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

            # ラベルをパディング
            label_features = [{"input_ids": feature["labels"]} for feature in features]
            labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

            # パディングトークンを-100に置き換え(損失計算で無視される)
            labels = labels_batch["input_ids"].masked_fill(
                labels_batch.attention_mask.ne(1), -100
            )

            # bos tokenが存在する場合は削除
            if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
                labels = labels[:, 1:]

            batch["labels"] = labels

            return batch

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # 評価メトリクス
    import evaluate

    metric = evaluate.load("wer")  # Word Error Rate

    def compute_metrics(pred):
        """評価メトリクスの計算"""
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # -100を pad_token_id に置き換え
        label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

        # デコード
        pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # WER計算
        wer = metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 訓練設定
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-japanese",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        warmup_steps=500,
        num_train_epochs=3,
        evaluation_strategy="steps",
        eval_steps=1000,
        save_steps=1000,
        logging_steps=100,
        generation_max_length=225,
        predict_with_generate=True,
        fp16=True,  # 混合精度訓練
        push_to_hub=False,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
    )

    # トレーナーの初期化
    # trainer = Seq2SeqTrainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=dataset["train"],
    #     eval_dataset=dataset["test"],
    #     data_collator=data_collator,
    #     compute_metrics=compute_metrics,
    #     tokenizer=processor.feature_extractor,
    # )

    # 訓練開始
    # trainer.train()

    # モデルの保存
    # model.save_pretrained("./whisper-japanese-finetuned")
    # processor.save_pretrained("./whisper-japanese-finetuned")

    print("✓ ファインチューニング設定が完了しました")
    print("実際の訓練にはCommon Voiceなどのデータセットを使用してください")

    return model, processor

# 実行例
model, processor = finetune_whisper_japanese()

5.2 リアルタイム音声認識アプリケーション

マイクからの音声入力をリアルタイムで文字起こしするアプリケーションを構築します。

import pyaudio
import numpy as np
import whisper
import queue
import threading
from collections import deque

class RealtimeASR:
    """リアルタイム音声認識システム"""

    def __init__(self, model_name="base", language="ja"):
        """
        Args:
            model_name: Whisperモデルのサイズ
            language: 認識言語
        """
        # Whisperモデルのロード
        print(f"Whisperモデル '{model_name}' を読み込み中...")
        self.model = whisper.load_model(model_name)
        self.language = language

        # 音声設定
        self.RATE = 16000  # サンプリングレート
        self.CHUNK = 1024  # バッファサイズ
        self.CHANNELS = 1  # モノラル
        self.FORMAT = pyaudio.paInt16

        # 音声バッファ
        self.audio_queue = queue.Queue()
        self.audio_buffer = deque(maxlen=30)  # 30秒分のバッファ

        # PyAudioの初期化
        self.audio = pyaudio.PyAudio()

        # 状態管理
        self.is_recording = False
        self.transcription_thread = None

        print("✓ リアルタイムASRの初期化が完了しました")

    def audio_callback(self, in_data, frame_count, time_info, status):
        """音声入力コールバック"""
        if self.is_recording:
            # 音声データをキューに追加
            audio_data = np.frombuffer(in_data, dtype=np.int16)
            self.audio_queue.put(audio_data)

        return (in_data, pyaudio.paContinue)

    def start_recording(self):
        """録音開始"""
        self.is_recording = True

        # 音声ストリームを開く
        self.stream = self.audio.open(
            format=self.FORMAT,
            channels=self.CHANNELS,
            rate=self.RATE,
            input=True,
            frames_per_buffer=self.CHUNK,
            stream_callback=self.audio_callback
        )

        self.stream.start_stream()

        # 文字起こしスレッドを開始
        self.transcription_thread = threading.Thread(target=self.transcribe_loop)
        self.transcription_thread.start()

        print("🎤 録音を開始しました...")

    def stop_recording(self):
        """録音停止"""
        self.is_recording = False

        if hasattr(self, 'stream'):
            self.stream.stop_stream()
            self.stream.close()

        if self.transcription_thread:
            self.transcription_thread.join()

        print("⏹️  録音を停止しました")

    def transcribe_loop(self):
        """文字起こしループ"""
        print("📝 文字起こしを開始...")

        while self.is_recording:
            # 音声データを収集(1秒分)
            audio_chunks = []
            for _ in range(int(self.RATE / self.CHUNK)):
                try:
                    chunk = self.audio_queue.get(timeout=0.1)
                    audio_chunks.append(chunk)
                except queue.Empty:
                    continue

            if not audio_chunks:
                continue

            # 音声データを結合
            audio_data = np.concatenate(audio_chunks)
            self.audio_buffer.append(audio_data)

            # バッファから音声を取得(5秒分)
            if len(self.audio_buffer) >= 5:
                # 最新の5秒分を使用
                audio_segment = np.concatenate(list(self.audio_buffer)[-5:])

                # 正規化
                audio_segment = audio_segment.astype(np.float32) / 32768.0

                # 文字起こし
                try:
                    result = self.model.transcribe(
                        audio_segment,
                        language=self.language,
                        task="transcribe",
                        fp16=False,
                        temperature=0.0,
                        no_speech_threshold=0.6
                    )

                    text = result["text"].strip()
                    if text:
                        print(f"認識結果: {text}")

                except Exception as e:
                    print(f"エラー: {e}")

    def __del__(self):
        """クリーンアップ"""
        if hasattr(self, 'audio'):
            self.audio.terminate()


# 使用例
def realtime_asr_demo():
    """リアルタイムASRのデモ"""

    # ASRシステムの初期化
    asr = RealtimeASR(model_name="base", language="ja")

    try:
        # 録音開始
        asr.start_recording()

        # 10秒間録音
        import time
        print("10秒間話してください...")
        time.sleep(10)

        # 録音停止
        asr.stop_recording()

    except KeyboardInterrupt:
        print("\n中断されました")
        asr.stop_recording()

    print("✓ デモが完了しました")


# バッチ処理版(ファイルから)
def batch_transcribe_with_speaker_diarization(audio_file):
    """
    話者分離を含む音声認識
    (pyannote.audioなどのライブラリを使用)
    """
    import whisper

    # Whisperで文字起こし
    model = whisper.load_model("medium")
    result = model.transcribe(
        audio_file,
        language="ja",
        word_timestamps=True
    )

    # 話者分離(ダミー実装)
    # 実際にはpyannote.audioなどを使用
    print("=" * 50)
    print("文字起こし結果(話者付き):")
    print("=" * 50)

    current_speaker = "Speaker 1"
    for i, segment in enumerate(result["segments"]):
        # 簡易的な話者切り替え判定(実際には話者分離モデルを使用)
        if i > 0 and segment["start"] - result["segments"][i-1]["end"] > 2.0:
            current_speaker = "Speaker 2" if current_speaker == "Speaker 1" else "Speaker 1"

        start = segment["start"]
        end = segment["end"]
        text = segment["text"]

        print(f"[{current_speaker}] [{start:.2f}s - {end:.2f}s]")
        print(f"  {text}")
        print()

    return result


# 注: 実際の実行には pyaudio のインストールが必要
# pip install pyaudio
#
# macOSの場合:
# brew install portaudio
# pip install pyaudio

print("リアルタイムASRシステムの実装例を表示しました")
print("実行には 'pyaudio' のインストールが必要です")

5.3 完全なASRアプリケーション

Webインターフェースを持つ完全な音声認識アプリケーションを構築します。

import gradio as gr
import whisper
import numpy as np
from pathlib import Path

class ASRApplication:
    """Webベースの音声認識アプリケーション"""

    def __init__(self):
        """アプリケーションの初期化"""
        self.models = {}
        self.current_model = None

        # 利用可能なモデル
        self.available_models = {
            "tiny": "最速(39M parameters)",
            "base": "高速(74M parameters)",
            "small": "バランス型(244M parameters)",
            "medium": "高精度(769M parameters)",
            "large": "最高精度(1550M parameters)"
        }

        # 対応言語
        self.languages = {
            "自動検出": None,
            "日本語": "ja",
            "英語": "en",
            "中国語": "zh",
            "韓国語": "ko",
            "スペイン語": "es",
            "フランス語": "fr",
            "ドイツ語": "de"
        }

    def load_model(self, model_name):
        """モデルのロード(キャッシュ付き)"""
        if model_name not in self.models:
            print(f"モデル '{model_name}' を読み込み中...")
            self.models[model_name] = whisper.load_model(model_name)
            print(f"✓ モデル '{model_name}' の読み込みが完了しました")

        self.current_model = self.models[model_name]
        return self.current_model

    def transcribe_audio(self, audio_file, model_name, language, task,
                        include_timestamps, beam_size, temperature):
        """
        音声ファイルを文字起こし

        Args:
            audio_file: 音声ファイルのパス
            model_name: 使用するモデル
            language: 認識言語
            task: transcribe or translate
            include_timestamps: タイムスタンプを含めるか
            beam_size: ビームサーチの幅
            temperature: サンプリング温度
        """
        if audio_file is None:
            return "音声ファイルをアップロードしてください", ""

        try:
            # モデルのロード
            model = self.load_model(model_name)

            # 文字起こし
            result = model.transcribe(
                audio_file,
                language=self.languages.get(language),
                task=task,
                beam_size=beam_size,
                temperature=temperature,
                word_timestamps=include_timestamps
            )

            # 基本的な文字起こし結果
            transcription = result["text"]

            # 詳細情報
            details = self._format_details(result, include_timestamps)

            return transcription, details

        except Exception as e:
            return f"エラーが発生しました: {str(e)}", ""

    def _format_details(self, result, include_timestamps):
        """詳細情報のフォーマット"""
        details = []

        # 検出言語
        if "language" in result:
            details.append(f"検出言語: {result['language']}")

        # セグメント情報
        if include_timestamps and "segments" in result:
            details.append("\n" + "=" * 50)
            details.append("セグメント詳細:")
            details.append("=" * 50)

            for i, segment in enumerate(result["segments"], 1):
                start = segment["start"]
                end = segment["end"]
                text = segment["text"]

                details.append(f"\n[{i}] {start:.2f}s - {end:.2f}s")
                details.append(f"    {text}")

                # 単語レベルのタイムスタンプ
                if "words" in segment:
                    details.append("    単語:")
                    for word_info in segment["words"]:
                        word = word_info["word"]
                        w_start = word_info["start"]
                        w_end = word_info["end"]
                        details.append(f"      - {word:20s} [{w_start:.2f}s - {w_end:.2f}s]")

        return "\n".join(details)

    def create_interface(self):
        """Gradio インターフェースの作成"""

        with gr.Blocks(title="AI音声認識システム") as interface:
            gr.Markdown(
                """
                # 🎙️ AI音声認識システム

                Whisperを使用した高精度な音声認識システムです。
                音声ファイルをアップロードするか、マイクで録音して文字起こしを行います。
                """
            )

            with gr.Row():
                with gr.Column(scale=1):
                    # 入力コントロール
                    audio_input = gr.Audio(
                        sources=["upload", "microphone"],
                        type="filepath",
                        label="音声入力"
                    )

                    model_selector = gr.Dropdown(
                        choices=list(self.available_models.keys()),
                        value="base",
                        label="モデル選択",
                        info="精度と速度のトレードオフを選択"
                    )

                    language_selector = gr.Dropdown(
                        choices=list(self.languages.keys()),
                        value="自動検出",
                        label="言語"
                    )

                    task_selector = gr.Radio(
                        choices=["transcribe", "translate"],
                        value="transcribe",
                        label="タスク",
                        info="transcribe: 同じ言語で文字起こし / translate: 英語に翻訳"
                    )

                    with gr.Accordion("詳細設定", open=False):
                        include_timestamps = gr.Checkbox(
                            label="タイムスタンプを含める",
                            value=True
                        )

                        beam_size = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=5,
                            step=1,
                            label="ビームサイズ",
                            info="大きいほど精度向上、計算時間増加"
                        )

                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.1,
                            label="温度",
                            info="0: 決定的、>0: ランダム性あり"
                        )

                    transcribe_btn = gr.Button("文字起こし開始", variant="primary")

                with gr.Column(scale=2):
                    # 出力
                    transcription_output = gr.Textbox(
                        label="文字起こし結果",
                        lines=5,
                        max_lines=10
                    )

                    details_output = gr.Textbox(
                        label="詳細情報",
                        lines=15,
                        max_lines=30
                    )

            # イベントハンドラ
            transcribe_btn.click(
                fn=self.transcribe_audio,
                inputs=[
                    audio_input,
                    model_selector,
                    language_selector,
                    task_selector,
                    include_timestamps,
                    beam_size,
                    temperature
                ],
                outputs=[transcription_output, details_output]
            )

            # 使用例
            gr.Markdown(
                """
                ## 使い方

                1. **音声入力**: ファイルをアップロードするか、マイクボタンで録音
                2. **モデル選択**: 用途に応じてモデルサイズを選択
                   - リアルタイム処理: tiny/base
                   - 一般的な用途: small
                   - 高精度が必要: medium/large
                3. **言語選択**: 自動検出または特定の言語を選択
                4. **文字起こし開始**: ボタンをクリックして処理開始

                ## モデル情報

                | モデル | パラメータ数 | 速度 | 精度 | 推奨用途 |
                |--------|--------------|------|------|----------|
                | tiny   | 39M          | ★★★★★ | ★★☆☆☆ | リアルタイム処理 |
                | base   | 74M          | ★★★★☆ | ★★★☆☆ | 高速処理 |
                | small  | 244M         | ★★★☆☆ | ★★★★☆ | バランス型 |
                | medium | 769M         | ★★☆☆☆ | ★★★★★ | 高精度 |
                | large  | 1550M        | ★☆☆☆☆ | ★★★★★ | 最高精度 |
                """
            )

        return interface

    def launch(self, share=False):
        """アプリケーションの起動"""
        interface = self.create_interface()
        interface.launch(share=share)


# アプリケーションの実行
if __name__ == "__main__":
    app = ASRApplication()

    print("=" * 50)
    print("AI音声認識システムを起動中...")
    print("=" * 50)

    # アプリケーション起動
    # app.launch(share=False)

    # 注: Gradioのインストールが必要
    # pip install gradio

    print("✓ アプリケーションの設定が完了しました")
    print("実行には 'gradio' のインストールが必要です")
    print("インストール: pip install gradio")

練習問題

問題1: CTC損失の理解

問題: CTCがアライメント情報なしで学習できる理由を説明し、Blankトークンの役割を述べてください。

解答例:

CTCは全ての可能なアライメントパスの確率を周辺化することで、明示的なアライメント情報なしに学習できます。具体的には:

問題2: AttentionメカニズムとCTCの違い

問題: Attention-basedモデルとCTCベースのモデルの主な違いを、アーキテクチャと学習の観点から説明してください。

解答例:

観点 CTC Attention-based
アーキテクチャ Encoder + 線形分類器 Encoder + Attention + Decoder
アライメント Monotonic(単調) 柔軟(Attentionで決定)
言語モデル 条件付き独立(外部LM必要) Decoderに統合
計算コスト 低い 高い
長距離依存 苦手 得意
問題3: RNN-Tの実装

問題: RNN-Transducerの3つの主要コンポーネント(Encoder, Prediction Network, Joint Network)の役割を説明し、簡単な実装を示してください。

解答: 本文中のRNN-Tの実装を参照してください。各コンポーネントの役割:

問題4: Whisperのファインチューニング

問題: Whisperを特定のドメイン(例: 医療、法律)にファインチューニングする際の主な考慮点を述べてください。

解答例:

  1. データセット:
    • ドメイン固有の音声データを収集
    • 専門用語の正確な転写
    • 多様な話者と音響条件
  2. 語彙の拡張:
    • ドメイン固有の用語をトークナイザーに追加
    • 略語や専門表記の処理
  3. 学習率と正則化:
    • 低い学習率(1e-5程度)で慎重に学習
    • 過学習防止のためのドロップアウト
  4. 評価:
    • ドメイン固有のテストセットでWER評価
    • 専門用語の認識精度を個別に評価
問題5: リアルタイムASRの最適化

問題: リアルタイム音声認識システムにおいて、レイテンシと精度のトレードオフを最適化する方法を3つ挙げてください。

解答例:

  1. モデルサイズの選択:
    • リアルタイム処理には軽量モデル(tiny/base)を使用
    • 必要に応じてモデル蒸留で小型化
  2. ストリーミング処理:
    • RNN-Tなどストリーミング対応アーキテクチャを使用
    • チャンクサイズと認識精度のバランス調整
    • ルックアヘッドの制限
  3. ハードウェア最適化:
    • GPU/TPUの活用
    • 量子化(INT8など)で推論を高速化
    • バッチ処理の活用

まとめ

本章では、深層学習による音声認識の主要技術を学びました:

これらの技術を組み合わせることで、様々なシナリオに対応した高精度な音声認識システムを構築できます。次章では、音声合成(TTS)と音声変換について学びます。

参考文献