
Chapter 3: Speech Recognition with Deep Learning

State-of-the-art Speech Recognition with CTC, Attention, RNN-T, and Whisper

📚 Introduction to Speech and Audio Processing âąī¸ 60 minutes đŸˇī¸ ML-D03

What You Will Learn in This Chapter

This chapter introduces modern Automatic Speech Recognition (ASR) techniques using deep learning. We cover fundamental approaches such as CTC and Attention mechanisms, as well as state-of-the-art models like Whisper, with practical code examples.

1. CTC (Connectionist Temporal Classification)

1.1 What is CTC

CTC is a method that can learn sequence-to-sequence transformations where input and output lengths differ, without requiring explicit alignment information. In speech recognition, it automatically learns the correspondence between acoustic feature frames (input) and text (output).

Key Concepts of CTC

- Blank token: a special symbol (index 0 here) meaning "emit nothing" at a frame; it also separates repeated characters.
- Collapse rule: frame-level outputs become text by merging consecutive identical labels and then removing blanks.
- Alignment marginalization: the probability of a transcript is the sum over all frame-level paths that collapse to it, so no frame-level annotation is needed for training.
- Conditional independence: frame outputs are assumed independent given the input, which is why CTC is often paired with an external language model.

1.2 CTC Loss Function

CTC loss is computed by marginalizing the probabilities of all possible alignment paths. It can be efficiently calculated using the Forward-Backward algorithm.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
import torch.nn.functional as F

class CTCASRModel(nn.Module):
    """Speech recognition model using CTC loss"""

    def __init__(self, input_dim=80, hidden_dim=256, num_classes=29, num_layers=3):
        """
        Args:
            input_dim: Input feature dimension (MFCC or mel spectrogram)
            hidden_dim: Dimension of LSTM hidden layer
            num_classes: Number of output classes (alphabet + blank)
            num_layers: Number of LSTM layers
        """
        super(CTCASRModel, self).__init__()

        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        # Output dimension is doubled due to bidirectional LSTM
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x, lengths):
        """
        Args:
            x: (batch, time, features)
            lengths: Actual length of each sample
        Returns:
            log_probs: (time, batch, num_classes)
            output_lengths: Output length of each sample
        """
        # Efficiently process variable-length inputs with PackedSequence
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True
        )

        # Apply log softmax for CTC
        logits = self.classifier(output)
        log_probs = F.log_softmax(logits, dim=-1)

        # CTC expects (T, N, C) format
        log_probs = log_probs.transpose(0, 1)

        return log_probs, output_lengths


# Training example
def train_ctc_model():
    """CTC ASR model training example"""

    # Initialize model and CTC loss
    model = CTCASRModel(input_dim=80, hidden_dim=256, num_classes=29)
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Dummy data (in practice, obtained from data loader)
    batch_size = 4
    max_time = 100
    features = torch.randn(batch_size, max_time, 80)
    feature_lengths = torch.tensor([100, 95, 90, 85])

    # Target text (numerically encoded)
    targets = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 10, 0], [11, 12, 0, 0]])
    target_lengths = torch.tensor([4, 3, 3, 2])

    model.train()

    # Forward pass
    log_probs, output_lengths = model(features, feature_lengths)

    # Calculate CTC loss
    loss = ctc_loss(log_probs, targets, output_lengths, target_lengths)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"CTC Loss: {loss.item():.4f}")

    return model

# Execution example
if __name__ == "__main__":
    model = train_ctc_model()
    print("✓ CTC model training completed")

1.3 CTC Decoding

To obtain text from a trained CTC model, decoding is required. The main methods are Greedy Decoding and Beam Search. In both cases, the frame-level output is collapsed by merging consecutive identical labels and removing blanks; for example, the path "hh-e-ll-lo" (with '-' as the blank) collapses to "hello".

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0

import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict

class CTCDecoder:
    """CTC Decoder (Greedy & Beam Search)"""

    def __init__(self, labels, blank_idx=0):
        """
        Args:
            labels: List of character labels
            blank_idx: Index of blank symbol
        """
        self.labels = labels
        self.blank_idx = blank_idx

    def greedy_decode(self, log_probs):
        """
        Greedy Decoding: select the highest-probability label at each time step

        Args:
            log_probs: (time, num_classes) log probabilities
        Returns:
            decoded_text: Decoded text
        """
        # Get the index of the highest-probability label at each time step
        best_path = torch.argmax(log_probs, dim=-1)

        # Remove consecutive duplicates and blanks
        decoded = []
        prev_idx = self.blank_idx

        for idx in best_path:
            idx = idx.item()
            if idx != self.blank_idx and idx != prev_idx:
                decoded.append(self.labels[idx])
            prev_idx = idx

        return ''.join(decoded)

    def beam_search_decode(self, log_probs, beam_width=10):
        """
        Beam Search Decoding: More accurate decoding

        Args:
            log_probs: (time, num_classes) log probabilities
            beam_width: Beam width
        Returns:
            decoded_text: Decoded text
        """
        T, C = log_probs.shape
        log_probs = log_probs.cpu().numpy()

        # Beams: {sequence: probability}
        beams = {('', self.blank_idx): 0.0}  # (text, last_char): log_prob

        for t in range(T):
            new_beams = defaultdict(lambda: float('-inf'))

            for (text, last_char), log_prob in beams.items():
                for c in range(C):
                    new_log_prob = log_prob + log_probs[t, c]

                    if c == self.blank_idx:
                        # Do not modify text for blank
                        new_beams[(text, c)] = np.logaddexp(
                            new_beams[(text, c)], new_log_prob
                        )
                    else:
                        if c == last_char:
                            # Do not repeat if same as previous character
                            new_beams[(text, c)] = np.logaddexp(
                                new_beams[(text, c)], new_log_prob
                            )
                        else:
                            # Add new character
                            new_text = text + self.labels[c]
                            new_beams[(new_text, c)] = np.logaddexp(
                                new_beams[(new_text, c)], new_log_prob
                            )

            # Keep top beam_width beams
            beams = dict(sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width])

        # Return the highest-probability beam
        best_beam = max(beams.items(), key=lambda x: x[1])
        return best_beam[0][0]


# Decoding example
def decode_example():
    """CTC decoding execution example"""

    # Alphabet (0 is blank)
    labels = ['-', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
              'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', "'"]

    decoder = CTCDecoder(labels, blank_idx=0)

    # Dummy log probabilities (in practice, model output)
    T = 50
    C = len(labels)
    log_probs = torch.randn(T, C)
    log_probs = F.log_softmax(log_probs, dim=-1)

    # Greedy Decoding
    greedy_text = decoder.greedy_decode(log_probs)
    print(f"Greedy Decode: {greedy_text}")

    # Beam Search Decoding
    beam_text = decoder.beam_search_decode(log_probs, beam_width=10)
    print(f"Beam Search Decode: {beam_text}")

decode_example()

2. Attention-based Models

2.1 Listen, Attend and Spell (LAS)

LAS is a speech recognition model that combines an Encoder-Decoder architecture with an Attention mechanism. The Encoder converts acoustic features into high-level representations, and the Decoder generates text while focusing on necessary information through Attention.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
import torch.nn.functional as F

class ListenAttendSpell(nn.Module):
    """Listen, Attend and Spell model"""

    def __init__(self, input_dim=80, encoder_hidden=256, decoder_hidden=512,
                 vocab_size=29, num_layers=3):
        super(ListenAttendSpell, self).__init__()

        self.encoder = Listener(input_dim, encoder_hidden, num_layers)
        self.decoder = Speller(encoder_hidden * 2, decoder_hidden, vocab_size)

    def forward(self, inputs, input_lengths, targets=None, teacher_forcing_ratio=0.9):
        """
        Args:
            inputs: (batch, time, features)
            input_lengths: Length of each sample
            targets: (batch, target_len) Decoder target
            teacher_forcing_ratio: Teacher forcing ratio
        """
        # Encoder
        encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

        # Decoder
        if targets is not None:
            outputs = self.decoder(encoder_outputs, encoder_lengths, targets,
                                  teacher_forcing_ratio)
        else:
            outputs = self.decoder.inference(encoder_outputs, encoder_lengths)

        return outputs


class Listener(nn.Module):
    """Encoder: acousticfeatures to high-level representations"""

    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Listener, self).__init__()

        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        # Pyramidal LSTM: Compress time dimension to improve computational efficiency
        self.pyramid_lstm = nn.LSTM(
            hidden_dim * 4,  # 2 frames combined + bidirectional
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

    def forward(self, x, lengths):
        """
        Args:
            x: (batch, time, features)
            lengths: Length of each sample
        """
        # First layer LSTM
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, _ = self.lstm(packed)
        output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # Pyramidal: Combine 2 frames into 1 to halve time
        batch, time, features = output.size()
        if time % 2 == 1:
            output = output[:, :-1, :]
            time -= 1

        output = output.reshape(batch, time // 2, features * 2)
        lengths = lengths // 2

        # Second layer LSTM
        packed = nn.utils.rnn.pack_padded_sequence(
            output, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output, _ = self.pyramid_lstm(packed)
        output, lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        return output, lengths


class Speller(nn.Module):
    """Decoder: Generate text using Attention"""

    def __init__(self, encoder_dim, hidden_dim, vocab_size):
        super(Speller, self).__init__()

        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTMCell(hidden_dim + encoder_dim, hidden_dim)
        self.attention = BahdanauAttention(encoder_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim + encoder_dim, vocab_size)

    def forward(self, encoder_outputs, encoder_lengths, targets, teacher_forcing_ratio=0.9):
        """
        Training with Teacher Forcing
        """
        batch_size = encoder_outputs.size(0)
        max_len = targets.size(1)

        # Initial state
        hidden = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)
        cell = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)

        # Start token
        input_token = torch.zeros(batch_size, dtype=torch.long, device=encoder_outputs.device)

        outputs = []

        for t in range(max_len):
            # Embedding
            embedded = self.embedding(input_token)

            # Attention
            context, _ = self.attention(hidden, encoder_outputs, encoder_lengths)

            # LSTM
            lstm_input = torch.cat([embedded, context], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))

            # Output
            output = self.classifier(torch.cat([hidden, context], dim=1))
            outputs.append(output)

            # Teacher Forcing
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher_forcing:
                input_token = targets[:, t]
            else:
                input_token = output.argmax(dim=1)

        return torch.stack(outputs, dim=1)

    def inference(self, encoder_outputs, encoder_lengths, max_len=100):
        """Decoding during inference"""
        batch_size = encoder_outputs.size(0)

        hidden = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)
        cell = torch.zeros(batch_size, self.lstm.hidden_size, device=encoder_outputs.device)

        input_token = torch.zeros(batch_size, dtype=torch.long, device=encoder_outputs.device)

        outputs = []

        for t in range(max_len):
            embedded = self.embedding(input_token)
            context, _ = self.attention(hidden, encoder_outputs, encoder_lengths)

            lstm_input = torch.cat([embedded, context], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))

            output = self.classifier(torch.cat([hidden, context], dim=1))
            outputs.append(output)

            input_token = output.argmax(dim=1)

            # Stop at end token
            if (input_token == self.vocab_size - 1).all():
                break

        return torch.stack(outputs, dim=1)


class BahdanauAttention(nn.Module):
    """Bahdanau Attention mechanism"""

    def __init__(self, encoder_dim, decoder_dim):
        super(BahdanauAttention, self).__init__()

        self.encoder_projection = nn.Linear(encoder_dim, decoder_dim)
        self.decoder_projection = nn.Linear(decoder_dim, decoder_dim)
        self.v = nn.Linear(decoder_dim, 1)

    def forward(self, decoder_hidden, encoder_outputs, encoder_lengths):
        """
        Args:
            decoder_hidden: (batch, decoder_dim)
            encoder_outputs: (batch, time, encoder_dim)
            encoder_lengths: (batch,)
        Returns:
            context: (batch, encoder_dim)
            attention_weights: (batch, time)
        """
        batch_size, time_steps, _ = encoder_outputs.size()

        # Projection
        encoder_proj = self.encoder_projection(encoder_outputs)  # (batch, time, decoder_dim)
        decoder_proj = self.decoder_projection(decoder_hidden).unsqueeze(1)  # (batch, 1, decoder_dim)

        # Energy computation
        energy = self.v(torch.tanh(encoder_proj + decoder_proj)).squeeze(-1)  # (batch, time)

        # Mask (ignore padding)
        mask = torch.arange(time_steps, device=encoder_outputs.device).expand(
            batch_size, time_steps
        ) < encoder_lengths.unsqueeze(1)

        energy = energy.masked_fill(~mask, float('-inf'))

        # Attention weights
        attention_weights = F.softmax(energy, dim=1)  # (batch, time)

        # Context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attention_weights


# Usage example
def las_example():
    """LAS model usage example"""

    model = ListenAttendSpell(
        input_dim=80,
        encoder_hidden=256,
        decoder_hidden=512,
        vocab_size=29
    )

    # Dummy data
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(0, 29, (batch_size, 20))

    # Training mode
    outputs = model(inputs, input_lengths, targets, teacher_forcing_ratio=0.9)
    print(f"Training output shape: {outputs.shape}")

    # Inference mode
    model.eval()
    with torch.no_grad():
        predictions = model(inputs, input_lengths, targets=None)
        print(f"Inference output shape: {predictions.shape}")

las_example()

2.2 Transformer for ASR

Applying the Transformer architecture to ASR enables parallel processing and more effectively captures long-range dependencies.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
import math

class TransformerASR(nn.Module):
    """Transformer based ASR model"""

    def __init__(self, input_dim=80, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, vocab_size=29, dropout=0.1):
        super(TransformerASR, self).__init__()

        # Input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )

        # Output layer
        self.output_projection = nn.Linear(d_model, vocab_size)

        # Parameters
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Embedding (for decoder input)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """
        Args:
            src: (batch, src_len, input_dim)
            tgt: (batch, tgt_len)
            src_mask: Encoder self-attention mask
            tgt_mask: Decoder self-attention mask (causal)
            src_key_padding_mask: (batch, src_len) Padding mask
            tgt_key_padding_mask: (batch, tgt_len) Padding mask
        """
        # Prepare encoder input
        src = self.input_projection(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        # Prepare decoder input
        tgt_embedded = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_embedded = self.pos_encoder(tgt_embedded)

        # Transformer
        output = self.transformer(
            src, tgt_embedded,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        # Output projection
        output = self.output_projection(output)

        return output

    def generate_square_subsequent_mask(self, sz):
        """Generate causal mask (to prevent seeing future information)"""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
        return mask


class PositionalEncoding(nn.Module):
    """Positional Encoding for Transformer"""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute positional encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# Training example
def train_transformer_asr():
    """Transformer ASR training example"""

    model = TransformerASR(
        input_dim=80,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        vocab_size=29
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is padding
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98))

    # Dummy data
    batch_size = 8
    src = torch.randn(batch_size, 100, 80)
    tgt = torch.randint(1, 29, (batch_size, 30))

    # Padding mask
    src_key_padding_mask = torch.zeros(batch_size, 100).bool()
    tgt_key_padding_mask = torch.zeros(batch_size, 30).bool()

    # Causal mask (must match the decoder input length: tgt[:, :-1] has tgt.size(1) - 1 tokens)
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1) - 1).to(tgt.device)

    # Forward
    output = model(src, tgt[:, :-1], tgt_mask=tgt_mask,
                   src_key_padding_mask=src_key_padding_mask,
                   tgt_key_padding_mask=tgt_key_padding_mask[:, :-1])

    # Loss calculation
    loss = criterion(output.reshape(-1, model.vocab_size), tgt[:, 1:].reshape(-1))

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Transformer ASR Loss: {loss.item():.4f}")

    return model

train_transformer_asr()
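
The example above covers only teacher-forced training; at inference time the decoder has to generate tokens autoregressively. Below is a minimal greedy-decoding sketch for the TransformerASR class defined above. The start and end token indices (sos_idx=1, eos_idx=28) are arbitrary assumptions for illustration, not values fixed by the model.

# Minimal autoregressive (greedy) decoding sketch for TransformerASR.
# sos_idx / eos_idx are illustrative assumptions, not defined by the model above.
import torch

def greedy_decode_transformer(model, src, max_len=50, sos_idx=1, eos_idx=28):
    model.eval()
    with torch.no_grad():
        tokens = torch.full((src.size(0), 1), sos_idx, dtype=torch.long)
        for _ in range(max_len):
            tgt_mask = model.generate_square_subsequent_mask(tokens.size(1))
            logits = model(src, tokens, tgt_mask=tgt_mask)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            if (next_token.squeeze(1) == eos_idx).all():
                break
    return tokens

# Usage with dummy features (the untrained model will produce arbitrary tokens)
model = TransformerASR(input_dim=80, d_model=512, nhead=8, vocab_size=29)
decoded = greedy_decode_transformer(model, torch.randn(2, 100, 80))
print(f"Decoded token shape: {decoded.shape}")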

2.3 Joint CTC-Attention

By combining CTC and Attention, we can leverage the strengths of each. CTC helps learn alignment, while Attention utilizes contextual information.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0
# (reuses the Listener and Speller classes from Section 2.1)

import torch
import torch.nn as nn
import torch.nn.functional as F

class JointCTCAttention(nn.Module):
    """Hybrid model combining CTC and Attention"""

    def __init__(self, input_dim=80, encoder_hidden=256, decoder_hidden=512,
                 vocab_size=29, num_layers=3, ctc_weight=0.3):
        super(JointCTCAttention, self).__init__()

        # Shared Encoder
        self.encoder = Listener(input_dim, encoder_hidden, num_layers)

        # CTC classifier
        self.ctc_classifier = nn.Linear(encoder_hidden * 2, vocab_size)

        # Attention-based Decoder
        self.decoder = Speller(encoder_hidden * 2, decoder_hidden, vocab_size)

        # CTC loss weight
        self.ctc_weight = ctc_weight
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    def forward(self, inputs, input_lengths, targets, target_lengths=None,
                teacher_forcing_ratio=0.9):
        """
        Args:
            inputs: (batch, time, features)
            input_lengths: Length of acoustic features for each sample
            targets: (batch, target_len) Text target
            target_lengths: Length of each target (for CTC)
            teacher_forcing_ratio: Teacher forcing ratio
        Returns:
            ctc_loss: CTC loss
            attention_loss: Attention loss
            combined_loss: Combined loss
        """
        # Encoder (shared)
        encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

        # CTC branch
        ctc_logits = self.ctc_classifier(encoder_outputs)
        ctc_log_probs = F.log_softmax(ctc_logits, dim=-1)
        ctc_log_probs = ctc_log_probs.transpose(0, 1)  # (T, N, C)

        # Attention branch
        attention_outputs = self.decoder(encoder_outputs, encoder_lengths,
                                        targets, teacher_forcing_ratio)

        return ctc_log_probs, encoder_lengths, attention_outputs

    def compute_loss(self, ctc_log_probs, encoder_lengths, attention_outputs,
                     targets, target_lengths):
        """
        Loss calculation
        """
        # CTC loss
        ctc_loss = self.ctc_loss(ctc_log_probs, targets, encoder_lengths, target_lengths)

        # Attention loss (Cross Entropy)
        attention_loss = F.cross_entropy(
            attention_outputs.reshape(-1, attention_outputs.size(-1)),
            targets.reshape(-1),
            ignore_index=0
        )

        # Combined loss
        combined_loss = self.ctc_weight * ctc_loss + (1 - self.ctc_weight) * attention_loss

        return ctc_loss, attention_loss, combined_loss

    def recognize(self, inputs, input_lengths, beam_width=10):
        """
        Inference: Decoding using both CTC and Attention
        """
        self.eval()
        with torch.no_grad():
            # Encoder
            encoder_outputs, encoder_lengths = self.encoder(inputs, input_lengths)

            # CTC branch (For prefix beam search)
            ctc_logits = self.ctc_classifier(encoder_outputs)
            ctc_probs = F.softmax(ctc_logits, dim=-1)

            # Attention branch
            attention_outputs = self.decoder.inference(encoder_outputs, encoder_lengths)
            attention_probs = F.softmax(attention_outputs, dim=-1)

            # Decode by combining CTC and Attention scores
            # (In practice, more complex beam search algorithms are used)
            combined_probs = (self.ctc_weight * ctc_probs[0] +
                            (1 - self.ctc_weight) * attention_probs)

            predictions = combined_probs.argmax(dim=-1)

            return predictions


# Training example
def train_joint_model():
    """Joint CTC-Attention model training"""

    model = JointCTCAttention(
        input_dim=80,
        encoder_hidden=256,
        decoder_hidden=512,
        vocab_size=29,
        ctc_weight=0.3
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Dummy data
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(1, 29, (batch_size, 20))
    target_lengths = torch.tensor([20, 18, 17, 15])

    # Forward
    ctc_log_probs, encoder_lengths, attention_outputs = model(
        inputs, input_lengths, targets, target_lengths
    )

    # Calculate loss
    ctc_loss, attention_loss, combined_loss = model.compute_loss(
        ctc_log_probs, encoder_lengths, attention_outputs, targets, target_lengths
    )

    # Backward
    optimizer.zero_grad()
    combined_loss.backward()
    optimizer.step()

    print(f"CTC Loss: {ctc_loss.item():.4f}")
    print(f"Attention Loss: {attention_loss.item():.4f}")
    print(f"Combined Loss: {combined_loss.item():.4f}")

    return model

train_joint_model()

3. RNN-Transducer (RNN-T)

3.1 What is RNN-T

RNN-Transducer is a model suitable for streaming speech recognition. It consists of three components: an acoustic model (Encoder), a language model (Prediction Network), and a Joint Network. Unlike CTC, it can explicitly incorporate a language model.

Characteristics of RNN-T

- Frame-synchronous decoding: at every encoder frame the model emits either a label or a blank, which makes it a natural fit for streaming recognition.
- Internal language model: the Prediction Network conditions on previously emitted labels, so linguistic context is captured without an external LM.
- Three components: the Encoder (Transcription Network), the Prediction Network, and the Joint Network that combines their outputs.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNTransducer(nn.Module):
    """RNN-Transducer model"""

    def __init__(self, input_dim=80, encoder_dim=256, pred_dim=256,
                 joint_dim=512, vocab_size=29, num_layers=3):
        """
        Args:
            input_dim: Input feature dimension
            encoder_dim: Encoder hidden layer dimension
            pred_dim: Prediction Network hidden layer dimension
            joint_dim: Joint Network hidden layer dimension
            vocab_size: Vocabulary size (including blank)
        """
        super(RNNTransducer, self).__init__()

        # Encoder (Transcription Network)
        self.encoder = nn.LSTM(
            input_dim,
            encoder_dim,
            num_layers,
            batch_first=True,
            bidirectional=False,  # Unidirectional so the model can run in streaming mode
            dropout=0.2
        )

        # Prediction Network (language model)
        self.embedding = nn.Embedding(vocab_size, pred_dim)
        self.prediction = nn.LSTM(
            pred_dim,
            pred_dim,
            num_layers,
            batch_first=True,
            dropout=0.2
        )

        # Joint Network
        self.joint = JointNetwork(encoder_dim, pred_dim, joint_dim, vocab_size)

        self.vocab_size = vocab_size
        self.blank_idx = 0

    def forward(self, inputs, input_lengths, targets, target_lengths):
        """
        Args:
            inputs: (batch, time, features) acoustic features
            input_lengths: Input length of each sample
            targets: (batch, target_len) Target labels
            target_lengths: Target length of each sample
        Returns:
            joint_output: (batch, time, target_len+1, vocab_size)
        """
        # Encoder
        encoder_out, _ = self.encoder(inputs)  # (batch, time, encoder_dim)

        # Prediction Network
        # Prepend the start token
        batch_size = targets.size(0)
        start_tokens = torch.zeros(batch_size, 1, dtype=torch.long, device=targets.device)
        pred_input = torch.cat([start_tokens, targets], dim=1)

        pred_embedded = self.embedding(pred_input)
        pred_out, _ = self.prediction(pred_embedded)  # (batch, target_len+1, pred_dim)

        # Joint Network
        joint_output = self.joint(encoder_out, pred_out)

        return joint_output

    def greedy_decode(self, inputs, input_lengths, max_len=100):
        """
        Simplified greedy decoding for inference (advances one frame per step;
        a full RNN-T decoder may emit multiple labels per frame)
        """
        self.eval()
        with torch.no_grad():
            batch_size = inputs.size(0)

            # Encoder
            encoder_out, _ = self.encoder(inputs)
            time_steps = encoder_out.size(1)

            # Initialization
            predictions = []
            pred_hidden = None

            # Start token
            pred_input = torch.zeros(batch_size, 1, dtype=torch.long, device=inputs.device)

            for t in range(time_steps):
                # Current encoder output
                enc_t = encoder_out[:, t:t+1, :]  # (batch, 1, encoder_dim)

                # Prediction Network
                pred_embedded = self.embedding(pred_input)
                pred_out, pred_hidden = self.prediction(pred_embedded, pred_hidden)

                # Joint Network
                joint_out = self.joint(enc_t, pred_out)  # (batch, 1, 1, vocab_size)

                # Select the highest-probability token
                prob = F.softmax(joint_out.squeeze(1).squeeze(1), dim=-1)
                pred_token = prob.argmax(dim=-1)

                # Add to output only if not blank
                if pred_token.item() != self.blank_idx:
                    predictions.append(pred_token.item())
                    pred_input = pred_token.unsqueeze(1)

                if len(predictions) >= max_len:
                    break

            return predictions


class JointNetwork(nn.Module):
    """Joint Network: Combine outputs of Encoder and Prediction Network"""

    def __init__(self, encoder_dim, pred_dim, joint_dim, vocab_size):
        super(JointNetwork, self).__init__()

        self.encoder_proj = nn.Linear(encoder_dim, joint_dim)
        self.pred_proj = nn.Linear(pred_dim, joint_dim)
        self.output_proj = nn.Linear(joint_dim, vocab_size)

    def forward(self, encoder_out, pred_out):
        """
        Args:
            encoder_out: (batch, time, encoder_dim)
            pred_out: (batch, target_len, pred_dim)
        Returns:
            joint_out: (batch, time, target_len, vocab_size)
        """
        # Projection
        enc_proj = self.encoder_proj(encoder_out)  # (batch, time, joint_dim)
        pred_proj = self.pred_proj(pred_out)  # (batch, target_len, joint_dim)

        # Broadcast and add
        # (batch, time, 1, joint_dim) + (batch, 1, target_len, joint_dim)
        joint = torch.tanh(
            enc_proj.unsqueeze(2) + pred_proj.unsqueeze(1)
        )  # (batch, time, target_len, joint_dim)

        # Output projection
        output = self.output_proj(joint)  # (batch, time, target_len, vocab_size)

        return output


# RNN-T Loss (Simplified version)
class RNNTLoss(nn.Module):
    """RNN-T loss function (Forward-Backward algorithm)"""

    def __init__(self, blank_idx=0):
        super(RNNTLoss, self).__init__()
        self.blank_idx = blank_idx

    def forward(self, logits, targets, input_lengths, target_lengths):
        """
        Args:
            logits: (batch, time, target_len+1, vocab_size)
            targets: (batch, target_len)
            input_lengths: (batch,)
            target_lengths: (batch,)
        """
        # In practice, use torchaudio.functional.rnnt_loss for the exact loss
        # (a sketch follows the training example below); here a naive
        # per-step approximation is used for simplicity
        batch_size, time, _, vocab_size = logits.size()

        # Greedy path approximation
        log_probs = F.log_softmax(logits, dim=-1)

        # Calculate probabilities of blank and non-blank at each time step
        loss = 0
        for b in range(batch_size):
            for t in range(input_lengths[b]):
                for u in range(target_lengths[b] + 1):
                    if u < target_lengths[b]:
                        # Non-blank: probability of correct label
                        target_label = targets[b, u]
                        loss -= log_probs[b, t, u, target_label]
                    else:
                        # blank
                        loss -= log_probs[b, t, u, self.blank_idx]

        return loss / batch_size


# Usage example
def train_rnnt():
    """RNN-T model training example"""

    model = RNNTransducer(
        input_dim=80,
        encoder_dim=256,
        pred_dim=256,
        joint_dim=512,
        vocab_size=29
    )

    criterion = RNNTLoss(blank_idx=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Dummy data
    batch_size = 4
    inputs = torch.randn(batch_size, 100, 80)
    input_lengths = torch.tensor([100, 95, 90, 85])
    targets = torch.randint(1, 29, (batch_size, 20))
    target_lengths = torch.tensor([20, 18, 17, 15])

    # Forward
    logits = model(inputs, input_lengths, targets, target_lengths)

    # Loss
    loss = criterion(logits, targets, input_lengths, target_lengths)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"RNN-T Loss: {loss.item():.4f}")

    # Inference example
    predictions = model.greedy_decode(inputs[:1], input_lengths[:1])
    print(f"Predicted tokens: {predictions}")

    return model

train_rnnt()
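
The RNNTLoss class above is only a rough per-step approximation for demonstration purposes. When torchaudio is installed, the exact RNN-T loss (computed with the forward-backward recursion over the time × label lattice) can be used instead. A minimal sketch, assuming a recent torchaudio release that provides torchaudio.functional.rnnt_loss:

# Exact RNN-T loss via torchaudio (assumes a torchaudio release with rnnt_loss)
import torch
import torchaudio.functional as TAF

def exact_rnnt_loss(logits, targets, input_lengths, target_lengths, blank_idx=0):
    """Exact RNN-T loss over the (time x target) output lattice"""
    return TAF.rnnt_loss(
        logits,                  # (batch, time, target_len + 1, vocab), raw logits
        targets.int(),           # (batch, target_len), must not contain the blank
        input_lengths.int(),     # (batch,)
        target_lengths.int(),    # (batch,)
        blank=blank_idx,
        reduction="mean",
    )

# Drop-in replacement inside train_rnnt():
# loss = exact_rnnt_loss(logits, targets, input_lengths, target_lengths)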

3.2 Comparison of CTC and RNN-T

Feature             | CTC                                  | RNN-Transducer
Language model      | External LM required                 | Integrated internally (Prediction Network)
Streaming           | Possible (with unidirectional LSTM)  | Designed for streaming
Computational cost  | Low                                  | Somewhat high (Joint Network)
Accuracy            | Moderate                             | High (benefits from the internal LM)
Training stability  | Relatively stable                    | Can be somewhat unstable

4. Whisper

4.1 What is OpenAI Whisper

Whisper is a multilingual speech recognition model developed by OpenAI. Trained on 680,000 hours of multilingual data, it supports speech recognition, translation, and language identification tasks for 99 languages.

Characteristics of Whisper

- Trained on 680,000 hours of weakly supervised multilingual audio collected from the web.
- A single encoder-decoder Transformer handles transcription, translation into English, and language identification.
- Covers 99 languages and ships in several sizes (tiny, base, small, medium, large).
- Robust to accents, background noise, and technical vocabulary without task-specific fine-tuning.

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0
# - openai-whisper

import whisper
import torch
import numpy as np

# Basic Whisper model usage
def basic_whisper_usage():
    """Basic usage of Whisper"""

    # Load model (choose from base, small, medium, large)
    model = whisper.load_model("base")

    print(f"Model size: {model.dims}")
    print(f"Number of supported languages: {len(whisper.tokenizer.LANGUAGES)}")

    # Load audio file and transcribe
    # audio = whisper.load_audio("audio.mp3")
    # audio = whisper.pad_or_trim(audio)

    # Dummy audio (in practice, loaded from file)
    audio = np.random.randn(16000 * 10).astype(np.float32)  # 10 seconds of audio
    audio = whisper.pad_or_trim(audio)  # pad/trim to the 30-second window Whisper expects

    # Convert to log-mel spectrogram
    mel = whisper.log_mel_spectrogram(torch.from_numpy(audio)).to(model.device)

    # Language detection
    _, probs = model.detect_language(mel)
    detected_language = max(probs, key=probs.get)
    print(f"Detected language: {detected_language} (probability: {probs[detected_language]:.2f})")

    # Options configuration
    options = whisper.DecodingOptions(
        language="ja",  # Specify Japanese
        task="transcribe",  # transcribe or translate
        fp16=False  # Whether to use FP16
    )

    # Decoding
    result = whisper.decode(model, mel, options)

    print(f"Transcription result: {result.text}")
    print(f"Average logprobability: {result.avg_logprob:.4f}")
    print(f"Compression ratio: {result.compression_ratio:.2f}")

    return model, result


# Higher-level API
def transcribe_audio_file(audio_path, model_size="base"):
    """
    Transcribe audio file

    Args:
        audio_path: Audio file path
        model_size: Model size (tiny, base, small, medium, large)
    """
    # Load model
    model = whisper.load_model(model_size)

    # Transcription
    result = model.transcribe(
        audio_path,
        language="ja",  # Japanese
        task="transcribe",
        verbose=True,  # Show progress
        temperature=0.0,  # Temperature (diversity control)
        best_of=5,  # Number of candidates when sampling (temperature > 0)
        beam_size=5,  # Beam width
        patience=1.0,  # Beam search patience
        length_penalty=1.0,  # Length penalty
        compression_ratio_threshold=2.4,  # Compression ratio threshold
        logprob_threshold=-1.0,  # Log probability threshold
        no_speech_threshold=0.6  # No-speech threshold
    )

    # Display results
    print("=" * 50)
    print("Transcription result:")
    print("=" * 50)
    print(result["text"])
    print()

    # Results by segment
    print("Segment details:")
    for segment in result["segments"]:
        start = segment["start"]
        end = segment["end"]
        text = segment["text"]
        print(f"[{start:.2f}s - {end:.2f}s] {text}")

    return result


# Transcription with timestamps
def transcribe_with_timestamps(audio_path):
    """Transcription with word-level timestamps"""

    model = whisper.load_model("base")

    # word_timestamps=True enables word-level timestamps
    result = model.transcribe(
        audio_path,
        language="ja",
        word_timestamps=True  # Word-level timestamps
    )

    # Display word-level results
    for segment in result["segments"]:
        if "words" in segment:
            for word_info in segment["words"]:
                word = word_info["word"]
                start = word_info["start"]
                end = word_info["end"]
                probability = word_info.get("probability", 0.0)
                print(f"{word:15s} [{start:6.2f}s - {end:6.2f}s] (probability: {probability:.3f})")

    return result


# Multilingual audio processing
def multilingual_transcription(audio_path):
    """Multilingual audio processing"""

    model = whisper.load_model("medium")  # Medium or larger recommended for multilingual

    # Auto-detect language
    result = model.transcribe(
        audio_path,
        task="transcribe",
        language=None  # Auto-detect
    )

    detected_language = result["language"]
    print(f"Detected language: {detected_language}")
    print(f"Transcription: {result['text']}")

    # Translate to English
    translation = model.transcribe(
        audio_path,
        task="translate",  # Translate to English
        language=detected_language
    )

    print(f"English translation: {translation['text']}")

    return result, translation


# Execution example
if __name__ == "__main__":
    print("Whisper usage examples")
    print("=" * 50)

    # Basic usage
    model, result = basic_whisper_usage()
    print("✓ basicTranscription completed")

    # Note: Uncomment the following to use actual audio files
    # result = transcribe_audio_file("sample.mp3", model_size="base")
    # result = transcribe_with_timestamps("sample.mp3")
    # result, translation = multilingual_transcription("sample.mp3")
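
As an alternative to the openai-whisper package, the same checkpoints can be run through the Hugging Face transformers ASR pipeline. A minimal sketch, assuming transformers is installed and ffmpeg is available to decode the audio file ("sample.mp3" is a placeholder path, as above):

# Requirements (assumed):
# - transformers>=4.30.0
# - ffmpeg available on the system for audio decoding

from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    chunk_length_s=30,  # split long recordings into 30-second chunks
)

result = asr(
    "sample.mp3",  # placeholder path
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)
print(result["text"])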

4.2 Whisper architecture

Whisper adopts an Encoder-Decoder architecture. The Encoder processes acoustic features, and the Decoder generates text. Both are Transformer-based.

Pipeline overview: audio input → mel spectrogram → Encoder (Transformer) → acoustic representation → Decoder (Transformer) → text output

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class WhisperArchitecture(nn.Module):
    """
    Overview of the Whisper architecture
    (the actual Whisper is more complex, but this shows the main structure)
    """

    def __init__(self,
                 n_mels=80,
                 n_audio_ctx=1500,
                 n_audio_state=512,
                 n_audio_head=8,
                 n_audio_layer=6,
                 n_vocab=51865,
                 n_text_ctx=448,
                 n_text_state=512,
                 n_text_head=8,
                 n_text_layer=6):
        """
        Args:
            n_mels: Number of mel filter banks
            n_audio_ctx: Audio context length
            n_audio_state: Encoder state dimension
            n_audio_head: Number of encoder attention heads
            n_audio_layer: Number of encoder layers
            n_vocab: Vocabulary size
            n_text_ctx: Text context length
            n_text_state: Decoder state dimension
            n_text_head: Number of decoder attention heads
            n_text_layer: Number of decoder layers
        """
        super().__init__()

        # Encoder: processes acoustic features
        self.encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=n_audio_ctx,
            n_state=n_audio_state,
            n_head=n_audio_head,
            n_layer=n_audio_layer
        )

        # Decoder: Generate text
        self.decoder = TextDecoder(
            n_vocab=n_vocab,
            n_ctx=n_text_ctx,
            n_state=n_text_state,
            n_head=n_text_head,
            n_layer=n_text_layer
        )

    def forward(self, mel, tokens):
        """
        Args:
            mel: (batch, n_mels, time) Mel Spectrogram
            tokens: (batch, seq_len) Tokens
        """
        # Encoder
        audio_features = self.encoder(mel)

        # Decoder
        logits = self.decoder(tokens, audio_features)

        return logits


class AudioEncoder(nn.Module):
    """Whisper Audio Encoder"""

    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        # Convolutional layers for feature extraction
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)

        # Positional embedding
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        # Transformer layers
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head) for _ in range(n_layer)
        ])

        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x):
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            (batch, time//2, n_state)
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, time, n_state)

        # Positional embedding
        x = x + self.positional_embedding[:x.size(1)]

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    """Whisper Text Decoder"""

    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
        super().__init__()

        # Token embedding
        self.token_embedding = nn.Embedding(n_vocab, n_state)

        # Positional embedding
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        # Transformer layers (with cross-attention)
        self.blocks = nn.ModuleList([
            DecoderBlock(n_state, n_head) for _ in range(n_layer)
        ])

        self.ln = nn.LayerNorm(n_state)

    def forward(self, tokens, audio_features):
        """
        Args:
            tokens: (batch, seq_len)
            audio_features: (batch, audio_len, n_state)
        Returns:
            (batch, seq_len, n_vocab)
        """
        x = self.token_embedding(tokens)
        x = x + self.positional_embedding[:x.size(1)]

        for block in self.blocks:
            x = block(x, audio_features)

        x = self.ln(x)

        # Weight tying: reuse the token embedding matrix as the output projection
        logits = x @ self.token_embedding.weight.T

        return logits


class TransformerBlock(nn.Module):
    """Transformer Encoder Block"""

    def __init__(self, n_state, n_head):
        super().__init__()

        self.attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.attn_ln = nn.LayerNorm(n_state)

        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4),
            nn.GELU(),
            nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attn(x, x, x)
        x = self.attn_ln(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.mlp_ln(x + mlp_out)

        return x


class DecoderBlock(nn.Module):
    """Transformer Decoder Block (with cross-attention)"""

    def __init__(self, n_state, n_head):
        super().__init__()

        # Self-attention
        self.self_attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.self_attn_ln = nn.LayerNorm(n_state)

        # Cross-attention
        self.cross_attn = nn.MultiheadAttention(n_state, n_head, batch_first=True)
        self.cross_attn_ln = nn.LayerNorm(n_state)

        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4),
            nn.GELU(),
            nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x, audio_features):
        # Self-attention with an explicit causal mask (no access to future positions)
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn_out, _ = self.self_attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.self_attn_ln(x + attn_out)

        # Cross-attention
        cross_out, _ = self.cross_attn(x, audio_features, audio_features)
        x = self.cross_attn_ln(x + cross_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.mlp_ln(x + mlp_out)

        return x


# Display architecture overview
def show_architecture_info():
    """Information about Whisper model sizes"""

    models_info = {
        "tiny": {
            "parameters": "39M",
            "n_audio_layer": 4,
            "n_text_layer": 4,
            "n_state": 384,
            "n_head": 6
        },
        "base": {
            "parameters": "74M",
            "n_audio_layer": 6,
            "n_text_layer": 6,
            "n_state": 512,
            "n_head": 8
        },
        "small": {
            "parameters": "244M",
            "n_audio_layer": 12,
            "n_text_layer": 12,
            "n_state": 768,
            "n_head": 12
        },
        "medium": {
            "parameters": "769M",
            "n_audio_layer": 24,
            "n_text_layer": 24,
            "n_state": 1024,
            "n_head": 16
        },
        "large": {
            "parameters": "1550M",
            "n_audio_layer": 32,
            "n_text_layer": 32,
            "n_state": 1280,
            "n_head": 20
        }
    }

    print("Whisper Model sizecomparison")
    print("=" * 70)
    print(f"{'Model':<10} {'Parameters':<12} {'Layers':<10} {'State':<10} {'Heads':<10}")
    print("=" * 70)

    for model, info in models_info.items():
        print(f"{model:<10} {info['parameters']:<12} "
              f"{info['n_audio_layer']}/{info['n_text_layer']:<10} "
              f"{info['n_state']:<10} {info['n_head']:<10}")

    print()
    print("Recommended usage:")
    print("- tiny/base: Real-time processing, resource-constrained environments")
    print("- small: Balanced approach, general purpose")
    print("- medium/large: High accuracy requirements, multilingual processing")

show_architecture_info()

5. Practical ASR Systems

5.1 Fine-tuning Whisper

By fine-tuning Whisper for specific domains or languages, even higher accuracy can be achieved.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0
# - transformers>=4.30.0
# - datasets, evaluate

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, Audio
import torch

def finetune_whisper_japanese():
    """
    Whisper fine-tuning for Japanese speech recognition
    """

    # Load model and processor
    model_name = "openai/whisper-small"
    processor = WhisperProcessor.from_pretrained(model_name, language="Japanese", task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    # Japanese-specific configuration
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="Japanese",
        task="transcribe"
    )
    model.config.suppress_tokens = []

    # Prepare dataset (example: Common Voice Japanese)
    # dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ja", split="train[:100]")

    # Dummy dataset (use actual datasets like the one above in practice)
    def prepare_dataset(batch):
        """Dataset preprocessing"""
        # Resample audio to 16kHz
        audio = batch["audio"]

        # Convert to Mel Spectrogram
        batch["input_features"] = processor(
            audio["array"],
            sampling_rate=audio["sampling_rate"],
            return_tensors="pt"
        ).input_features[0]

        # Prepare labels
        batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids

        return batch

    # Apply preprocessing to dataset
    # dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)

    # Data collator
    from dataclasses import dataclass
    from typing import Any, Dict, List, Union

    @dataclass
    class DataCollatorSpeechSeq2SeqWithPadding:
        """Speech data collator"""

        processor: Any

        def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            # Pad input features
            input_features = [{"input_features": feature["input_features"]} for feature in features]
            batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

            # Pad labels
            label_features = [{"input_ids": feature["labels"]} for feature in features]
            labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

            # Replace padding tokens with -100 (ignored in loss calculation)
            labels = labels_batch["input_ids"].masked_fill(
                labels_batch.attention_mask.ne(1), -100
            )

            # Remove bos token if present
            if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
                labels = labels[:, 1:]

            batch["labels"] = labels

            return batch

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Evaluation metrics
    import evaluate

    metric = evaluate.load("wer")  # Word Error Rate

    def compute_metrics(pred):
        """Calculate evaluation metrics"""
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # Replace -100 with pad_token_id
        label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

        # Decode
        pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # Calculate WER
        wer = metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # Training configuration
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-japanese",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        warmup_steps=500,
        num_train_epochs=3,
        evaluation_strategy="steps",
        eval_steps=1000,
        save_steps=1000,
        logging_steps=100,
        generation_max_length=225,
        predict_with_generate=True,
        fp16=True,  # Mixed precision training
        push_to_hub=False,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
    )

    # Initialize trainer
    # trainer = Seq2SeqTrainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=dataset["train"],
    #     eval_dataset=dataset["test"],
    #     data_collator=data_collator,
    #     compute_metrics=compute_metrics,
    #     tokenizer=processor.feature_extractor,
    # )

    # Start training
    # trainer.train()

    # Save model
    # model.save_pretrained("./whisper-japanese-finetuned")
    # processor.save_pretrained("./whisper-japanese-finetuned")

    print("✓ Fine-tuning configuration completed")
    print("Use datasets like Common Voice for actual training")

    return model, processor

# Execution example
model, processor = finetune_whisper_japanese()
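
Once the model and processor are loaded (fine-tuned or not), inference through the Hugging Face API looks like the following minimal sketch. The random array stands in for a real 16 kHz mono recording:

# Inference sketch with the Hugging Face Whisper model/processor loaded above.
# The random array is a stand-in for a real 16 kHz mono waveform.
import numpy as np
import torch

audio_array = np.random.randn(16000 * 5).astype(np.float32)  # dummy 5-second clip

inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    predicted_ids = model.generate(inputs.input_features, max_new_tokens=128)
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(f"Transcription: {text}")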

5.2 Real-time Speech Recognition Application

Build an application that transcribes audio input from a microphone in real-time.

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - openai-whisper
# - pyaudio (see installation note at the end of this example)

import pyaudio
import numpy as np
import whisper
import queue
import threading
from collections import deque

class RealtimeASR:
    """Real-time speech recognition system"""

    def __init__(self, model_name="base", language="ja"):
        """
        Args:
            model_name: Whisper model size
            language: Recognition language
        """
        # Load Whisper model
        print(f"Loading Whisper model '{model_name}'...")
        self.model = whisper.load_model(model_name)
        self.language = language

        # Audio settings
        self.RATE = 16000  # Sampling rate
        self.CHUNK = 1024  # Buffer size
        self.CHANNELS = 1  # Mono
        self.FORMAT = pyaudio.paInt16

        # Audio buffer
        self.audio_queue = queue.Queue()
        self.audio_buffer = deque(maxlen=30)  # 30-second buffer

        # PyAudio initialization
        self.audio = pyaudio.PyAudio()

        # State management
        self.is_recording = False
        self.transcription_thread = None

        print("✓ Real-time ASR initialization completed")

    def audio_callback(self, in_data, frame_count, time_info, status):
        """Audio input callback"""
        if self.is_recording:
            # Add audio data to queue
            audio_data = np.frombuffer(in_data, dtype=np.int16)
            self.audio_queue.put(audio_data)

        return (in_data, pyaudio.paContinue)

    def start_recording(self):
        """Start recording"""
        self.is_recording = True

        # Open audio stream
        self.stream = self.audio.open(
            format=self.FORMAT,
            channels=self.CHANNELS,
            rate=self.RATE,
            input=True,
            frames_per_buffer=self.CHUNK,
            stream_callback=self.audio_callback
        )

        self.stream.start_stream()

        # Start transcription thread
        self.transcription_thread = threading.Thread(target=self.transcribe_loop)
        self.transcription_thread.start()

        print("🎤 Recording started...")

    def stop_recording(self):
        """Stop recording"""
        self.is_recording = False

        if hasattr(self, 'stream'):
            self.stream.stop_stream()
            self.stream.close()

        if self.transcription_thread:
            self.transcription_thread.join()

        print("âšī¸  Recording stopped")

    def transcribe_loop(self):
        """Transcription loop"""
        print("📝 Starting transcription...")

        while self.is_recording:
            # Collect audio data (1 second worth)
            audio_chunks = []
            for _ in range(int(self.RATE / self.CHUNK)):
                try:
                    chunk = self.audio_queue.get(timeout=0.1)
                    audio_chunks.append(chunk)
                except queue.Empty:
                    continue

            if not audio_chunks:
                continue

            # Concatenate audio data
            audio_data = np.concatenate(audio_chunks)
            self.audio_buffer.append(audio_data)

            # Get audio from buffer (5 seconds worth)
            if len(self.audio_buffer) >= 5:
                # Use latest 5 seconds
                audio_segment = np.concatenate(list(self.audio_buffer)[-5:])

                # Normalize
                audio_segment = audio_segment.astype(np.float32) / 32768.0

                # Transcription
                try:
                    result = self.model.transcribe(
                        audio_segment,
                        language=self.language,
                        task="transcribe",
                        fp16=False,
                        temperature=0.0,
                        no_speech_threshold=0.6
                    )

                    text = result["text"].strip()
                    if text:
                        print(f"Recognition result: {text}")

                except Exception as e:
                    print(f"Error: {e}")

    def __del__(self):
        """Cleanup"""
        if hasattr(self, 'audio'):
            self.audio.terminate()


# Usage example
def realtime_asr_demo():
    """Real-time ASR demo"""

    # Initialize ASR system
    asr = RealtimeASR(model_name="base", language="ja")

    try:
        # Start recording
        asr.start_recording()

        # Record for 10 seconds
        import time
        print("Please speak for 10 seconds...")
        time.sleep(10)

        # Stop recording
        asr.stop_recording()

    except KeyboardInterrupt:
        print("\nInterrupted")
        asr.stop_recording()

    print("✓ Demo completed")


# Batch processing version (from file)
def batch_transcribe_with_speaker_diarization(audio_file):
    """
    Speech recognition with speaker diarization
    (using libraries like pyannote.audio)
    """
    import whisper

    # Whisper transcription
    model = whisper.load_model("medium")
    result = model.transcribe(
        audio_file,
        language="ja",
        word_timestamps=True
    )

    # Speaker diarization (dummy implementation)
    # Use libraries like pyannote.audio in practice
    print("=" * 50)
    print("Transcription result (with speakers):")
    print("=" * 50)

    current_speaker = "Speaker 1"
    for i, segment in enumerate(result["segments"]):
        # Simple speaker switch detection (use actual speaker diarization model in practice)
        if i > 0 and segment["start"] - result["segments"][i-1]["end"] > 2.0:
            current_speaker = "Speaker 2" if current_speaker == "Speaker 1" else "Speaker 1"

        start = segment["start"]
        end = segment["end"]
        text = segment["text"]

        print(f"[{current_speaker}] [{start:.2f}s - {end:.2f}s]")
        print(f"  {text}")
        print()

    return result
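
To replace the dummy speaker-switch logic above with an actual diarization model, a pipeline such as pyannote.audio can be combined with the Whisper segments. The sketch below is illustrative only: the pipeline identifier "pyannote/speaker-diarization-3.1", the need for a Hugging Face access token, and the overlap-based speaker assignment are assumptions that may need adjusting to your pyannote.audio version.

def transcribe_with_pyannote(audio_file, hf_token):
    """Assign Whisper segments to speakers found by pyannote.audio (sketch)."""
    import whisper
    from pyannote.audio import Pipeline

    # Speaker diarization (pipeline name is an assumption; check the
    # pyannote.audio documentation for the current identifier)
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )
    diarization = pipeline(audio_file)

    # Whisper transcription with segment-level timestamps
    model = whisper.load_model("medium")
    result = model.transcribe(audio_file, language="ja")

    # Assign each Whisper segment to the speaker whose turns overlap it the most
    for segment in result["segments"]:
        seg_start, seg_end = segment["start"], segment["end"]
        overlaps = {}
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            overlap = min(seg_end, turn.end) - max(seg_start, turn.start)
            if overlap > 0:
                overlaps[speaker] = overlaps.get(speaker, 0.0) + overlap
        speaker = max(overlaps, key=overlaps.get) if overlaps else "UNKNOWN"
        print(f"[{speaker}] [{seg_start:.2f}s - {seg_end:.2f}s] {segment['text']}")

    return result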


# Note: pyaudio installation is required for actual execution
# pip install pyaudio
#
# For macOS:
# brew install portaudio
# pip install pyaudio

print("Real-time ASR system implementation example displayed")
print("Requires 'pyaudio' installation for execution")

5.3 Complete ASR Application

Build a complete speech recognition application with a web interface.

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - openai-whisper
# - gradio (see the installation note at the end of this example)

import gradio as gr
import whisper
import numpy as np
from pathlib import Path

class ASRApplication:
    """Web-based speech recognition application"""

    def __init__(self):
        """Application initialization"""
        self.models = {}
        self.current_model = None

        # Available models
        self.available_models = {
            "tiny": "Fastest (39M parameters)",
            "base": "Fast (74M parameters)",
            "small": "Balanced (244M parameters)",
            "medium": "High accuracy (769M parameters)",
            "large": "Highest accuracy (1550M parameters)"
        }

        # Supported languages
        self.languages = {
            "Auto-detect": None,
            "Japanese": "ja",
            "English": "en",
            "Chinese": "zh",
            "Korean": "ko",
            "Spanish": "es",
            "French": "fr",
            "German": "de"
        }

    def load_model(self, model_name):
        """Load model (with caching)"""
        if model_name not in self.models:
            print(f"Loading model '{model_name}'...")
            self.models[model_name] = whisper.load_model(model_name)
            print(f"✓ Model '{model_name}' loading completed")

        self.current_model = self.models[model_name]
        return self.current_model

    def transcribe_audio(self, audio_file, model_name, language, task,
                        include_timestamps, beam_size, temperature):
        """
        Transcribe audio file

        Args:
            audio_file: Audio file path
            model_name: Model to use
            language: Recognition language
            task: transcribe or translate
            include_timestamps: Whether to include timestamps
            beam_size: Beam search width
            temperature: Sampling temperature
        """
        if audio_file is None:
            return "Please upload an audio file", ""

        try:
            # Load model
            model = self.load_model(model_name)

            # Transcription
            result = model.transcribe(
                audio_file,
                language=self.languages.get(language),
                task=task,
                beam_size=beam_size,
                temperature=temperature,
                word_timestamps=include_timestamps
            )

            # Basic transcription result
            transcription = result["text"]

            # Detailed information
            details = self._format_details(result, include_timestamps)

            return transcription, details

        except Exception as e:
            return f"An error occurred: {str(e)}", ""

    def _format_details(self, result, include_timestamps):
        """Format detailed information"""
        details = []

        # Detected language
        if "language" in result:
            details.append(f"Detected language: {result['language']}")

        # Segment information
        if include_timestamps and "segments" in result:
            details.append("\n" + "=" * 50)
            details.append("Segment details:")
            details.append("=" * 50)

            for i, segment in enumerate(result["segments"], 1):
                start = segment["start"]
                end = segment["end"]
                text = segment["text"]

                details.append(f"\n[{i}] {start:.2f}s - {end:.2f}s")
                details.append(f"    {text}")

                # Word-level timestamps
                if "words" in segment:
                    details.append("    Words:")
                    for word_info in segment["words"]:
                        word = word_info["word"]
                        w_start = word_info["start"]
                        w_end = word_info["end"]
                        details.append(f"      - {word:20s} [{w_start:.2f}s - {w_end:.2f}s]")

        return "\n".join(details)

    def create_interface(self):
        """Create Gradio interface"""

        with gr.Blocks(title="AI Speech Recognition System") as interface:
            gr.Markdown(
                """
                # đŸŽ™ī¸ AI Speech Recognition System

                High-accuracy speech recognition system using Whisper.
                Upload an audio file or record with microphone to perform transcription.
                """
            )

            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    audio_input = gr.Audio(
                        sources=["upload", "microphone"],
                        type="filepath",
                        label="Audio input"
                    )

                    model_selector = gr.Dropdown(
                        choices=list(self.available_models.keys()),
                        value="base",
                        label="Model selection",
                        info="Select accuracy vs speed tradeoff"
                    )

                    language_selector = gr.Dropdown(
                        choices=list(self.languages.keys()),
                        value="Auto-detect",
                        label="Language"
                    )

                    task_selector = gr.Radio(
                        choices=["transcribe", "translate"],
                        value="transcribe",
                        label="Task",
                        info="transcribe: Transcription in same language / translate: Translate to English"
                    )

                    with gr.Accordion("Advanced settings", open=False):
                        include_timestamps = gr.Checkbox(
                            label="Include timestamps",
                            value=True
                        )

                        beam_size = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=5,
                            step=1,
                            label="Beam size",
                            info="Larger values improve accuracy but increase computation time"
                        )

                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.1,
                            label="Temperature",
                            info="0: Deterministic, >0: Adds randomness"
                        )

                    transcribe_btn = gr.Button("Start transcription", variant="primary")

                with gr.Column(scale=2):
                    # Output
                    transcription_output = gr.Textbox(
                        label="Transcription result",
                        lines=5,
                        max_lines=10
                    )

                    details_output = gr.Textbox(
                        label="Detailed information",
                        lines=15,
                        max_lines=30
                    )

            # Event handler
            transcribe_btn.click(
                fn=self.transcribe_audio,
                inputs=[
                    audio_input,
                    model_selector,
                    language_selector,
                    task_selector,
                    include_timestamps,
                    beam_size,
                    temperature
                ],
                outputs=[transcription_output, details_output]
            )

            # Usage instructions
            gr.Markdown(
                """
                ## How to Use

                1. **Audio input**: Upload a file or record with microphone button
                2. **Model selection**: Select model size based on your needs
                   - Real-time processing: tiny/base
                   - General use: small
                   - High accuracy needed: medium/large
                3. **Language selection**: Auto-detect or select specific language
                4. **Start transcription**: Click button to begin processing

                ## Model Information

                | Model | Parameters | Speed | Accuracy | Recommended Use |
                |--------|------------|-------|----------|-----------------|
                | tiny   | 39M        | ★★★★★ | ★★☆☆☆ | Real-time processing |
                | base   | 74M        | ★★★★☆ | ★★★☆☆ | Fast processing |
                | small  | 244M       | ★★★☆☆ | ★★★★☆ | Balanced |
                | medium | 769M       | ★★☆☆☆ | ★★★★★ | High accuracy |
                | large  | 1550M      | ★☆☆☆☆ | ★★★★★ | Highest accuracy |
                """
            )

        return interface

    def launch(self, share=False):
        """Launch application"""
        interface = self.create_interface()
        interface.launch(share=share)


# Execute application
if __name__ == "__main__":
    app = ASRApplication()

    print("=" * 50)
    print("Starting AI Speech Recognition System...")
    print("=" * 50)

    # Launch application
    # app.launch(share=False)

    # Note: Gradio installation required
    # pip install gradio

    print("✓ Application setup completed")
    print("Requires 'gradio' installation for execution")
    print("Install with: pip install gradio")

Practice Problems

Problem 1: Understanding CTC Loss

Problem: Explain why CTC can learn without alignment information and describe the role of the Blank token.

Sample Answer:

CTC can learn without explicit alignment information by marginalizing the probabilities of all possible alignment paths. Specifically:

  â€ĸ The CTC loss sums the probability of every frame-level path that collapses to the target label sequence (computed efficiently with the Forward-Backward algorithm), so the model never needs frame-level labels and discovers the alignment implicitly during training.
  â€ĸ The Blank token lets the model output "no label" on frames that carry no new character (silence, the sustained middle of a sound).
  â€ĸ The Blank token also separates consecutive identical characters, so genuine repeats (such as the "ll" in "hello") survive the collapse rule applied during decoding.

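The collapse rule is easiest to see in code. Below is a minimal sketch of greedy (best-path) CTC decoding, shown for illustration rather than as a replacement for the beam-search decoders used in practice: take the most probable class per frame, merge consecutive repeats, then drop blanks.

import torch

def ctc_greedy_decode(log_probs, blank=0):
    """
    Greedy CTC decoding for a single utterance.

    Args:
        log_probs: (time, num_classes) log probabilities
        blank: index of the blank token
    Returns:
        List of label indices after collapsing repeats and removing blanks
    """
    best_path = torch.argmax(log_probs, dim=-1).tolist()

    decoded = []
    previous = None
    for label in best_path:
        # Keep a label only if it differs from the previous frame (collapse
        # repeats) and is not the blank token
        if label != previous and label != blank:
            decoded.append(label)
        previous = label
    return decoded

# Example: the frame-level path [8, 8, 0, 5, 5, 0, 12, 12, 0, 12, 15]
# collapses to [8, 5, 12, 12, 15]; the blank (0) between the two 12s is
# what preserves the double letter ("ll" in "hello").
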

Problem 2: Differences Between Attention Mechanism and CTC

Problem: Explain the main differences between Attention-based models and CTC-based models from the perspectives of architecture and learning.

Sample Answer:

| Aspect                  | CTC                                             | Attention-based                    |
|-------------------------|-------------------------------------------------|------------------------------------|
| Architecture            | Encoder + linear classifier                     | Encoder + Attention + Decoder      |
| Alignment               | Monotonic                                       | Flexible (determined by Attention) |
| Language model          | Conditional independence (external LM required) | Integrated in the Decoder          |
| Computational cost      | Low                                             | High                               |
| Long-range dependencies | Weak                                            | Strong                             |

Problem 3: RNN-T Implementation

Problem: Explain the roles of the three main components of RNN-Transducer (Encoder, Prediction Network, Joint Network) and show a simple implementation.

Answer: Refer to the RNN-T implementation in the main text. Role of each component:

  â€ĸ Encoder: converts the acoustic feature frames into high-level representations (the acoustic-model part, analogous to a CTC encoder).
  â€ĸ Prediction Network: an autoregressive network over previously emitted labels, acting as an internal language model.
  â€ĸ Joint Network: combines the encoder output at each time step with the prediction-network output at each label position and produces the distribution over the next symbol, including blank.

A condensed sketch of the three components is shown below.
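This sketch is illustrative only: the layer sizes, depths, and the handling of the start token are assumptions, and the RNN-T loss computation is omitted.

import torch
import torch.nn as nn

class MinimalRNNT(nn.Module):
    """Condensed, illustrative sketch of the three RNN-T components."""

    def __init__(self, input_dim=80, vocab_size=29,
                 enc_dim=256, pred_dim=256, joint_dim=256):
        super().__init__()
        # Encoder: turns acoustic frames into high-level representations
        self.encoder = nn.LSTM(input_dim, enc_dim, num_layers=2, batch_first=True)
        # Prediction network: internal language model over previous output labels
        self.embedding = nn.Embedding(vocab_size, pred_dim)
        self.prediction = nn.LSTM(pred_dim, pred_dim, num_layers=1, batch_first=True)
        # Joint network: fuses both representations into per-symbol logits
        self.joint = nn.Sequential(
            nn.Linear(enc_dim + pred_dim, joint_dim),
            nn.Tanh(),
            nn.Linear(joint_dim, vocab_size),
        )

    def forward(self, features, labels):
        """
        Args:
            features: (batch, T, input_dim) acoustic features
            labels: (batch, U) previous labels, with a start/blank token prepended
        Returns:
            logits: (batch, T, U, vocab_size) over the time-by-label lattice
        """
        enc_out, _ = self.encoder(features)                     # (B, T, enc_dim)
        pred_out, _ = self.prediction(self.embedding(labels))   # (B, U, pred_dim)

        # Broadcast the time axis against the label axis and fuse
        enc_out = enc_out.unsqueeze(2)    # (B, T, 1, enc_dim)
        pred_out = pred_out.unsqueeze(1)  # (B, 1, U, pred_dim)
        joint_in = torch.cat(
            [enc_out.expand(-1, -1, pred_out.size(2), -1),
             pred_out.expand(-1, enc_out.size(1), -1, -1)],
            dim=-1,
        )
        return self.joint(joint_in)

# During training these logits would be passed to an RNN-T loss
# (e.g., torchaudio.functional.rnnt_loss); during streaming inference the
# encoder advances frame by frame while the prediction network advances
# only when a non-blank symbol is emitted.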

Problem 4: Fine-tuning Whisper

Problem: Describe the main considerations when fine-tuning Whisper for specific domains (e.g., medical, legal).

Sample Answer:

  1. Dataset:
    â€ĸ Collect domain-specific audio data
    â€ĸ Accurate transcription of technical terms
    â€ĸ Diverse speakers and acoustic conditions
  2. Vocabulary expansion:
    â€ĸ Add domain-specific terms to the tokenizer
    â€ĸ Handle abbreviations and specialized notation
  3. Learning rate and regularization (see the sketch after this list):
    â€ĸ Train carefully with a low learning rate (around 1e-5)
    â€ĸ Use dropout to prevent overfitting
  4. Evaluation:
    â€ĸ Evaluate WER on domain-specific test sets
    â€ĸ Evaluate recognition accuracy of technical terms separately

Problem 5: Optimizing Real-time ASR

Problem: List three methods to optimize the latency-accuracy tradeoff in real-time speech recognition systems.

Sample Answer:

  1. Model size selection:
    â€ĸ Use lightweight models (tiny/base) for real-time processing
    â€ĸ Use model distillation for downsizing if needed
  2. Streaming processing:
    â€ĸ Use streaming-capable architectures like RNN-T
    â€ĸ Balance chunk size against recognition accuracy
    â€ĸ Limit lookahead
  3. Hardware optimization (see the quantization sketch after this list):
    â€ĸ Utilize GPU/TPU
    â€ĸ Accelerate inference with quantization (e.g., INT8)
    â€ĸ Utilize batch processing


Summary

In this chapter, we learned the main techniques for speech recognition using deep learning:

  â€ĸ CTC: alignment-free training via the Blank token and marginalization over all paths
  â€ĸ Attention-based encoder-decoder models: flexible alignment with a language model integrated in the decoder
  â€ĸ RNN-Transducer: a streaming-friendly combination of encoder, prediction network, and joint network
  â€ĸ Whisper: large-scale pretrained multilingual ASR, its fine-tuning, and real-time and web-based applications

By combining these techniques, you can build highly accurate speech recognition systems for various scenarios. In the next chapter, we will learn about speech synthesis (TTS) and voice conversion.
