Back
intermediate
Modern LLM Architectures

BERT and Bidirectional Models

Master BERT's bidirectional architecture, masked language modeling, and how it revolutionized NLP understanding tasks

18 min read

BERT and Bidirectional Models

BERT (Bidirectional Encoder Representations from Transformers) revolutionized NLP by introducing deep bidirectional pre-training. Unlike GPT's left-to-right approach, BERT reads text in both directions simultaneously, enabling richer contextual understanding.

The Bidirectional Revolution

Bidirectional Context: The ability to use information from both before and after a token when building its representation, as opposed to unidirectional models that only see previous context. This enables richer understanding for tasks where the full input is available.

Why Bidirectional Matters

python
"""
Unidirectional (GPT) vs Bidirectional (BERT) context:

Sentence: "The bank of the river was muddy."

GPT (left-to-right):
- Processing "bank": only sees "The"
- Cannot use "river" to disambiguate meaning

BERT (bidirectional):
- Processing "bank": sees both "The" and "of the river was muddy"
- Can correctly understand "bank" as riverbank, not financial institution
"""

import torch
import torch.nn as nn

def demonstrate_context_importance():
    """Print two sentences in which only the right-hand context reveals what "bank" means."""

    examples = [
        {
            "sentence": "The bank can guarantee deposits.",
            "word": "bank",
            "left_context": "The",
            "right_context": "can guarantee deposits",
            "meaning": "financial institution"
        },
        {
            "sentence": "The bank was full of flowers.",
            "word": "bank",
            "left_context": "The",
            "right_context": "was full of flowers",
            "meaning": "riverbank/slope"
        }
    ]

    print("Bidirectional Context Disambiguation:\n")
    for case in examples:
        # Assemble the whole report for one example, then emit it in one go.
        # The trailing "\n" on the last entry leaves a blank line between examples.
        report_lines = [
            f"Sentence: {case['sentence']}",
            f"Word: '{case['word']}'",
            f"Left-only context: '{case['left_context']}' → ambiguous",
            f"With right context: '{case['right_context']}'",
            f"Correct meaning: {case['meaning']}\n",
        ]
        print("\n".join(report_lines))

demonstrate_context_importance()

Key Insight: Bidirectional models can use future context to better understand the present, making them superior for understanding tasks but unsuitable for generation (which requires autoregressive left-to-right processing).

BERT Architecture

Model Design

python
"""
BERT Model Specifications:

BERT-Base:
- Parameters: 110M
- Layers: 12 encoder blocks
- Hidden size: 768
- Attention heads: 12
- Max sequence length: 512
- Vocab size: 30,000 (WordPiece)

BERT-Large:
- Parameters: 340M
- Layers: 24 encoder blocks
- Hidden size: 1024
- Attention heads: 16
- Max sequence length: 512
"""

class BERTEmbedding(nn.Module):
    """Input layer that sums token, position, and segment embeddings.

    The summed vectors are layer-normalized and then passed through dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Dimensionality of every embedding vector.
        max_len: Number of learned position embeddings (longest usable sequence).
        num_segments: Number of distinct segment ids.
        dropout: Dropout probability applied to the normalized sum.
    """

    def __init__(self, vocab_size, hidden_size, max_len=512,
                 num_segments=2, dropout=0.1):
        super().__init__()

        # Three lookup tables sharing one output width so they can be added:
        # WordPiece tokens, learned (not sinusoidal) positions, and segments.
        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.position_embed = nn.Embedding(max_len, hidden_size)
        self.segment_embed = nn.Embedding(num_segments, hidden_size)

        self.norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_ids, segment_ids=None):
        """Embed a batch of token ids.

        Args:
            token_ids: Integer tensor of shape [batch_size, seq_len].
            segment_ids: Optional integer tensor shaped like ``token_ids``;
                when omitted every token is assigned segment 0.

        Returns:
            Tensor of shape [batch_size, seq_len, hidden_size].
        """
        _, num_positions = token_ids.shape

        if segment_ids is None:
            segment_ids = token_ids.new_zeros(token_ids.shape)

        # Position vectors are the same for every row of the batch, so they are
        # looked up once ([seq_len, hidden]) and broadcast across the batch
        # dimension by the addition below.
        positions = torch.arange(num_positions, device=token_ids.device)
        position_vectors = self.position_embed(positions)

        summed = self.token_embed(token_ids) + position_vectors
        summed = summed + self.segment_embed(segment_ids)

        normalized = self.norm(summed)
        return self.dropout(normalized)

class BERTEncoderBlock(nn.Module):
    """One post-norm transformer encoder layer: self-attention, then a GELU MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_size: Width of the input and output representations.
        num_heads: Number of attention heads.
        ff_size: Inner width of the feed-forward network; falls back to
            ``4 * hidden_size`` when left unset.
        dropout: Dropout probability used inside attention and after each sub-layer.
    """

    def __init__(self, hidden_size, num_heads, ff_size=None, dropout=0.1):
        super().__init__()
        if not ff_size:
            ff_size = 4 * hidden_size

        # batch_first=True: inputs and outputs are [batch, seq, hidden].
        self.attention = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )

        # Position-wise MLP: expand, GELU, project back, dropout.
        expand = nn.Linear(hidden_size, ff_size)
        contract = nn.Linear(ff_size, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block.

        Args:
            x: Tensor of shape [batch, seq, hidden].
            mask: Optional mask handed to ``nn.MultiheadAttention`` as ``attn_mask``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer; the returned attention weights are not needed here.
        attended = self.attention(x, x, x, attn_mask=mask)[0]
        after_attention = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer (its own dropout lives inside self.ffn).
        return self.norm2(after_attention + self.ffn(after_attention))

class BERTModel(nn.Module):
    """BERT encoder: embeddings, a stack of encoder blocks, and a [CLS] pooler.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Width of every hidden representation.
        num_layers: Number of stacked encoder blocks.
        num_heads: Attention heads per encoder block.
        max_len: Number of learned position embeddings.
    """

    def __init__(self, vocab_size=30000, hidden_size=768, num_layers=12,
                 num_heads=12, max_len=512):
        super().__init__()

        # Kept so forward() can expand a padding mask to one mask per head.
        self.num_heads = num_heads

        # Embeddings
        self.embeddings = BERTEmbedding(vocab_size, hidden_size, max_len)

        # Transformer encoder blocks
        self.encoder_blocks = nn.ModuleList([
            BERTEncoderBlock(hidden_size, num_heads)
            for _ in range(num_layers)
        ])

        # Dense + tanh applied to the first token's final hidden state.
        self.pooler = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh()
        )

    def _to_block_mask(self, input_ids, attention_mask):
        """Convert a padding mask into the mask format the encoder blocks accept.

        The encoder blocks pass their mask to ``nn.MultiheadAttention`` as
        ``attn_mask``, which must be ``(L, S)`` or ``(batch * num_heads, L, S)``
        and, for boolean masks, uses True to mean "may NOT attend". A
        ``[batch, seq_len]`` padding mask (nonzero = real token) matches neither
        the shape nor the polarity, so it is inverted and expanded here.

        A mask whose shape differs from ``input_ids`` is returned unchanged, so
        callers that already build an ``attn_mask`` keep working. If a caller
        passes an ``(L, S)`` mask while batch size happens to equal sequence
        length, it is interpreted as a padding mask.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() != 2 or attention_mask.shape != input_ids.shape:
            return attention_mask

        batch_size, seq_len = input_ids.shape
        # True where the key position is padding and must be ignored.
        blocked_keys = (attention_mask == 0).to(input_ids.device)
        # Every query position in a sequence ignores the same padded keys.
        blocked = blocked_keys[:, None, :].expand(batch_size, seq_len, seq_len)
        # nn.MultiheadAttention lays heads out as batch-major (index = b * heads + h),
        # so each sequence's mask is repeated num_heads times in a row.
        # NOTE: this materializes batch * heads * seq * seq booleans.
        return blocked.repeat_interleave(self.num_heads, dim=0)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer tensor of shape [batch, seq_len].
            segment_ids: Optional segment ids shaped like ``input_ids``.
            attention_mask: Optional padding mask shaped like ``input_ids``
                (nonzero/True = real token, 0/False = padding, the convention
                used by Hugging Face tokenizers). Masks of any other shape are
                forwarded to the blocks untouched as ``attn_mask``.

        Returns:
            Tuple ``(sequence_output, pooled_output)``: the final hidden states
            of shape [batch, seq_len, hidden_size] and the pooled first-token
            representation of shape [batch, hidden_size].
        """
        block_mask = self._to_block_mask(input_ids, attention_mask)

        # Embeddings
        x = self.embeddings(input_ids, segment_ids)

        # Encoder blocks
        for block in self.encoder_blocks:
            x = block(x, block_mask)

        # Pooled output (first token [CLS])
        pooled_output = self.pooler(x[:, 0])

        return x, pooled_output

# Model size calculation: instantiate BERTModel with 12 layers and hidden size
# 768 (remaining arguments use their defaults) and count every element of every
# registered parameter tensor.
model = BERTModel(num_layers=12, hidden_size=768)
total_params = sum(p.numel() for p in model.parameters())
print(f"BERT-Base parameters: {total_params:,}")  # ~110M

[CLS] Token: BERT adds a special [CLS] token at the beginning of every sequence. The final hidden state of this token is used as the aggregate sequence representation for classification tasks.

Masked Language Modeling (MLM)

Masked Language Modeling (MLM): A pre-training objective where random tokens in the input are masked (hidden), and the model learns to predict the original tokens based on bidirectional context, enabling deep bidirectional representations.

The core pre-training task that enables bidirectional learning.

MLM Implementation

python
"""
Masked Language Modeling Strategy:

1. Randomly select 15% of tokens
2. Of selected tokens:
   - 80% replace with [MASK]
   - 10% replace with random token
   - 10% keep unchanged

This prevents the model from only learning about [MASK] tokens.
"""

class MaskedLanguageModel(nn.Module):
    """Pre-training head that turns hidden states into vocabulary logits.

    Args:
        hidden_size: Width of the incoming hidden states.
        vocab_size: Number of output classes (one logit per vocabulary entry).
    """

    def __init__(self, hidden_size, vocab_size):
        super().__init__()

        # Dense -> GELU -> LayerNorm, applied independently at every position.
        dense = nn.Linear(hidden_size, hidden_size)
        self.transform = nn.Sequential(dense, nn.GELU(), nn.LayerNorm(hidden_size))

        # Final projection onto the vocabulary.
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states):
        """Return logits with the last dimension replaced by ``vocab_size``."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)

class MLMTrainer:
    """Builds masked inputs and labels for the masked-language-modeling objective.

    Args:
        vocab_size: Random replacement tokens are drawn uniformly from
            ``[0, vocab_size)``.
        mask_token_id: Id written into positions replaced by [MASK].
        mask_prob: Probability that a non-special token is selected for prediction.
        special_token_ids: Ids that are never selected. The default
            ``(0, 101, 102)`` corresponds to [PAD], [CLS] and [SEP].
    """

    def __init__(self, vocab_size, mask_token_id=103, mask_prob=0.15,
                 special_token_ids=(0, 101, 102)):
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.mask_prob = mask_prob
        self.special_token_ids = tuple(special_token_ids)

    def create_mlm_batch(self, input_ids):
        """
        Create masked inputs and labels for MLM

        Every helper tensor is created on ``input_ids.device`` so the method
        works for CPU and accelerator inputs alike. ``input_ids`` itself is
        not modified.

        Args:
            input_ids: Original token IDs [batch_size, seq_len]

        Returns:
            masked_input_ids: Input with some tokens masked
            labels: Original tokens at selected positions, -100 everywhere else
        """
        device = input_ids.device
        shape = input_ids.shape

        masked_input = input_ids.clone()
        labels = input_ids.clone()

        # Positions holding special tokens are never eligible for selection.
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in self.special_token_ids:
            special_tokens_mask |= input_ids == token_id

        # Select each remaining token independently with probability mask_prob.
        probability_matrix = torch.full(shape, float(self.mask_prob), device=device)
        probability_matrix.masked_fill_(special_tokens_mask, 0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # -100 is the ignore_index used by the loss, so only selected
        # positions contribute to it.
        labels[~masked_indices] = -100

        # 80% of selected positions: replace with [MASK]
        indices_replaced = (
            torch.bernoulli(torch.full(shape, 0.8, device=device)).bool() &
            masked_indices
        )
        masked_input[indices_replaced] = self.mask_token_id

        # Half of the remaining 20% (i.e. 10% overall): replace with a random token
        indices_random = (
            torch.bernoulli(torch.full(shape, 0.5, device=device)).bool() &
            masked_indices &
            ~indices_replaced
        )
        # Match the dtype of input_ids: indexed assignment requires source and
        # destination dtypes to agree.
        random_tokens = torch.randint(
            self.vocab_size, shape, dtype=input_ids.dtype, device=device
        )
        masked_input[indices_random] = random_tokens[indices_random]

        # Final 10%: keep the original token (nothing to do)

        return masked_input, labels

def demonstrate_mlm():
    """Demonstrate MLM masking on a single sentence.

    Tokenizes a sentence with the ``bert-base-uncased`` tokenizer (loaded via
    ``transformers``), masks three chosen words, and prints the left and right
    context available for predicting each of them.
    """
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    sentence = "The quick brown fox jumps over the lazy dog."
    tokens = tokenizer.tokenize(sentence)

    print("Original sentence:", sentence)
    print("Tokens:", tokens, "\n")

    # Simulate masking. The target words are located by token text instead of
    # by hard-coded indices, so the masked tokens are always the ones named
    # here regardless of how the tokenizer splits the sentence.
    words_to_mask = ('brown', 'jumps', 'lazy')
    masked_positions = [
        pos for pos, token in enumerate(tokens) if token in words_to_mask
    ]

    masked_version = tokens.copy()
    for pos in masked_positions:
        masked_version[pos] = '[MASK]'

    print("Masked version:", ' '.join(masked_version))
    print("\nMLM task: Predict the original tokens at [MASK] positions")
    print("Model sees context from both directions to predict:")

    for pos in masked_positions:
        left_context = ' '.join(tokens[:pos])
        right_context = ' '.join(tokens[pos+1:])
        print(f"\n  Position {pos}: '{tokens[pos]}'")
        print(f"    Left context:  ...{left_context}")
        print(f"    Right context: {right_context}...")

demonstrate_mlm()

Training MLM

python
# Complete MLM training loop
def train_mlm_epoch(model, mlm_head, dataloader, optimizer, device,
                    vocab_size=30000):
    """Train BERT with the MLM objective for one pass over ``dataloader``.

    Args:
        model: Encoder called as ``model(masked_input)``; must return a tuple
            whose first element is the per-token hidden states.
        mlm_head: Module mapping hidden states to vocabulary logits.
        dataloader: Iterable of batches, each a mapping with an
            ``'input_ids'`` tensor of shape [batch_size, seq_len].
        optimizer: Optimizer over the parameters of ``model`` and ``mlm_head``.
        device: Device the forward/backward pass runs on.
        vocab_size: Range random replacement tokens are drawn from; should
            match the vocabulary size of ``mlm_head``.

    Returns:
        Mean loss per batch, or 0.0 when ``dataloader`` yields no batches.
    """
    model.train()
    mlm_head.train()

    total_loss = 0.0
    num_batches = 0
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    mlm_trainer = MLMTrainer(vocab_size=vocab_size)

    for batch in dataloader:
        # Mask first, on whatever device the batch arrives on, and only then
        # move the results: this keeps the ids and the masking tensors
        # co-located while the masks are being sampled.
        masked_input, labels = mlm_trainer.create_mlm_batch(batch['input_ids'])
        masked_input = masked_input.to(device)
        labels = labels.to(device)

        # Forward pass
        hidden_states, _ = model(masked_input)
        predictions = mlm_head(hidden_states)

        # Flatten to [batch * seq, vocab] vs [batch * seq]; positions labelled
        # -100 are skipped by the criterion.
        loss = criterion(
            predictions.view(-1, predictions.size(-1)),
            labels.view(-1)
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches ourselves: works for iterables without __len__ and avoids
    # dividing by zero on an empty dataloader.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Example: Predict masked words
def predict_masked_tokens():
    """Fill in a [MASK] token with a pre-trained ``bert-base-uncased`` model.

    Prints the input sentence, the top prediction, and the five
    highest-scoring candidates for the masked position.
    """
    from transformers import BertForMaskedLM, BertTokenizer
    import torch

    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Example sentence with mask
    text = "The capital of France is [MASK]."
    inputs = tokenizer(text, return_tensors='pt')

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        predictions = model(**inputs).logits

    # Position(s) of [MASK] in the first (only) sequence of the batch.
    is_mask = (inputs.input_ids == tokenizer.mask_token_id)[0]
    mask_token_index = is_mask.nonzero(as_tuple=True)[0]

    # Vocabulary logits at the masked position(s).
    mask_logits = predictions[0, mask_token_index]

    predicted_token_id = mask_logits.argmax(axis=-1)
    predicted_token = tokenizer.decode(predicted_token_id)

    print(f"Input: {text}")
    print(f"Predicted word: {predicted_token}")

    # Five best candidates for the first masked position.
    top5_tokens = mask_logits.topk(5).indices[0].tolist()
    top5_words = []
    for token in top5_tokens:
        top5_words.append(tokenizer.decode([token]))

    print(f"Top 5 predictions: {top5_words}")

predict_masked_tokens()

Why 80-10-10? Using [MASK] 100% of the time would create a mismatch between pre-training and fine-tuning (where [MASK] never appears). The 80-10-10 split helps the model learn robust representations.

Next Sentence Prediction (NSP)

Next Sentence Prediction (NSP): A binary classification pre-training task where the model predicts whether two sentences appeared consecutively in the original text or were randomly paired, intended to help learn sentence-level relationships.

BERT's second pre-training objective for understanding sentence relationships.

NSP Implementation

python
"""
Next Sentence Prediction (NSP):

Task: Given two sentences A and B, predict if B actually follows A
in the original document.

Input format:
[CLS] Sentence A [SEP] Sentence B [SEP]

Labels:
- IsNext (1): B actually follows A
- NotNext (0): B is random sentence from corpus
"""

class NextSentencePrediction(nn.Module):
    """Binary classification head used for the NSP objective.

    Args:
        hidden_size: Width of the pooled representation fed to the head.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Two output classes: index 0 = NotNext, index 1 = IsNext.
        self.classifier = nn.Linear(hidden_size, 2)

    def forward(self, pooled_output):
        """Score a batch of pooled [CLS] representations.

        Args:
            pooled_output: Tensor of shape [batch_size, hidden_size].

        Returns:
            Logits of shape [batch_size, 2], ordered (NotNext, IsNext).
        """
        logits = self.classifier(pooled_output)
        return logits

class NSPDataset:
    """Samples (sentence A, sentence B, label) triples for NSP training.

    Documents are split into sentences on '.', with surrounding whitespace
    stripped and empty fragments (such as the one after a trailing period)
    discarded. The split is done once at construction time.

    Args:
        documents: Sequence of document strings.
    """

    def __init__(self, documents):
        self.documents = documents
        # Sentences per document, aligned index-for-index with `documents`.
        self._doc_sentences = [self._split_sentences(doc) for doc in documents]
        # Only documents with at least two sentences can supply a (A, next) pair.
        self._eligible_docs = [
            i for i, sentences in enumerate(self._doc_sentences)
            if len(sentences) >= 2
        ]

    @staticmethod
    def _split_sentences(doc):
        """Split a document on '.' and drop blank fragments."""
        stripped = (part.strip() for part in doc.split('.'))
        return [sentence for sentence in stripped if sentence]

    def _sample_negative(self, doc_idx, sentence_a, true_next):
        """Pick a sentence B that is neither A nor A's real successor.

        Prefers a sentence from a different document; falls back to the same
        document when no other document has a usable sentence. Returns None
        when no valid negative exists.
        """
        import random

        excluded = {sentence_a, true_next}
        num_docs = len(self._doc_sentences)

        if num_docs > 1:
            # Cheap path: draw one other document uniformly at random.
            other = random.randrange(num_docs - 1)
            if other >= doc_idx:
                other += 1
            pool = [s for s in self._doc_sentences[other] if s not in excluded]
            if pool:
                return random.choice(pool)

        # Slow path: scan all other documents, then the source document itself.
        pool = [
            s
            for i, sentences in enumerate(self._doc_sentences) if i != doc_idx
            for s in sentences if s not in excluded
        ]
        if not pool:
            pool = [s for s in self._doc_sentences[doc_idx] if s not in excluded]
        return random.choice(pool) if pool else None

    def create_nsp_pair(self):
        """
        Create one NSP training example

        Returns:
            sentence_a: First sentence
            sentence_b: Second sentence
            label: 1 if B follows A, 0 otherwise

        Raises:
            ValueError: If no document contains at least two sentences.
        """
        import random

        if not self._eligible_docs:
            raise ValueError(
                "NSPDataset needs at least one document with two or more sentences"
            )

        doc_idx = random.choice(self._eligible_docs)
        sentences = self._doc_sentences[doc_idx]

        # Sentence A must have a successor, so the last sentence is excluded.
        idx = random.randrange(len(sentences) - 1)
        sentence_a = sentences[idx]
        true_next = sentences[idx + 1]

        # 50% of time: next sentence (positive)
        if random.random() < 0.5:
            return sentence_a, true_next, 1

        # 50% of time: random sentence (negative)
        sentence_b = self._sample_negative(doc_idx, sentence_a, true_next)
        if sentence_b is None:
            # The corpus offers no sentence that could serve as a negative;
            # return the correctly-labelled positive pair instead of a
            # mislabelled one.
            return sentence_a, true_next, 1
        return sentence_a, sentence_b, 0

def demonstrate_nsp():
    """Demonstrate NSP with a pre-trained ``bert-base-uncased`` NSP head.

    Scores one coherent and one unrelated sentence pair and prints the
    probability that sentence B follows sentence A.
    """
    from transformers import BertForNextSentencePrediction, BertTokenizer
    import torch

    model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    prompt = "The sun rises in the east."
    examples = [
        # (heading, sentence B)
        ("Example 1 (Coherent):", "It sets in the west."),
        ("Example 2 (Incoherent):", "I love eating pizza on Fridays."),
    ]

    for position, (heading, sentence_b) in enumerate(examples):
        encoding = tokenizer(prompt, sentence_b, return_tensors='pt')

        with torch.no_grad():
            logits = model(**encoding).logits

        probs = torch.softmax(logits, dim=1)
        # Hugging Face's NSP head uses index 0 for "B is the continuation of A"
        # and index 1 for "B is a random sentence", so IsNext is column 0.
        is_next_prob = probs[0, 0].item()

        # Blank line between examples, none after the last one.
        separator = "\n" if position < len(examples) - 1 else ""

        print(heading)
        print(f"  Sentence A: {prompt}")
        print(f"  Sentence B: {sentence_b}")
        print(f"  IsNext probability: {is_next_prob:.3f}{separator}")

demonstrate_nsp()

NSP Controversy: Later research (RoBERTa) showed that NSP might not be necessary and could even hurt performance. Modern BERT variants often remove or modify this objective.

Combined Pre-training

Joint MLM + NSP Training

python
class BERTPreTraining(nn.Module):
    """BERT encoder combined with both pre-training heads (MLM and NSP).

    Args:
        bert_model: Encoder called as ``bert_model(input_ids, segment_ids,
            attention_mask)`` and returning ``(sequence_output, pooled_output)``.
        vocab_size: Output size of the MLM head.
        hidden_size: Width of the encoder's hidden states.
    """

    def __init__(self, bert_model, vocab_size, hidden_size):
        super().__init__()

        self.bert = bert_model
        # Token-level head (MLM) and sequence-level head (NSP).
        self.mlm_head = MaskedLanguageModel(hidden_size, vocab_size)
        self.nsp_head = NextSentencePrediction(hidden_size)

    def forward(self, input_ids, segment_ids, attention_mask=None):
        """Run the encoder once and apply both heads.

        Returns:
            Tuple ``(mlm_predictions, nsp_predictions)``: the MLM head applied
            to the per-token states and the NSP head applied to the pooled state.
        """
        token_states, pooled_state = self.bert(
            input_ids, segment_ids, attention_mask
        )

        mlm_predictions = self.mlm_head(token_states)
        nsp_predictions = self.nsp_head(pooled_state)
        return mlm_predictions, nsp_predictions

def train_bert_pretraining(model, dataloader, optimizer, device):
    """Run one pass of combined MLM + NSP training.

    Args:
        model: Module called as ``model(input_ids, segment_ids)`` and returning
            ``(mlm_predictions, nsp_predictions)``.
        dataloader: Iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'segment_ids'``, ``'mlm_labels'`` and
            ``'nsp_labels'``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device every batch tensor is moved to.

    Returns:
        Dict with the mean per-batch ``'mlm_loss'`` and ``'nsp_loss'``; both
        are 0.0 when ``dataloader`` yields no batches.
    """
    model.train()

    # MLM labels use -100 at positions that must not contribute to the loss.
    mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_criterion = nn.CrossEntropyLoss()

    total_mlm_loss = 0.0
    total_nsp_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # Move to device
        input_ids = batch['input_ids'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        mlm_labels = batch['mlm_labels'].to(device)
        nsp_labels = batch['nsp_labels'].to(device)

        # Forward pass
        mlm_preds, nsp_preds = model(input_ids, segment_ids)

        # Flatten token predictions to [batch * seq, vocab] for the criterion.
        mlm_loss = mlm_criterion(
            mlm_preds.view(-1, mlm_preds.size(-1)),
            mlm_labels.view(-1)
        )
        nsp_loss = nsp_criterion(nsp_preds, nsp_labels)

        # Combined loss: unweighted sum of the two objectives.
        total_loss = mlm_loss + nsp_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        total_mlm_loss += mlm_loss.item()
        total_nsp_loss += nsp_loss.item()
        num_batches += 1

    # Batches are counted explicitly so iterables without __len__ work and an
    # empty dataloader does not trigger a division by zero.
    if num_batches == 0:
        return {'mlm_loss': 0.0, 'nsp_loss': 0.0}

    return {
        'mlm_loss': total_mlm_loss / num_batches,
        'nsp_loss': total_nsp_loss / num_batches
    }

Fine-tuning BERT

Task-Specific Adaptations

python
"""
BERT Fine-tuning for Different Tasks:

1. Single Sentence Classification (sentiment, topic):
   [CLS] sentence [SEP] → Use [CLS] representation

2. Sentence Pair Classification (entailment, similarity):
   [CLS] sentence_a [SEP] sentence_b [SEP] → Use [CLS] representation

3. Question Answering (SQuAD):
   [CLS] question [SEP] passage [SEP] → Predict start/end spans

4. Token Classification (NER, POS tagging):
   Use each token's representation
"""

class BERTForSequenceClassification(nn.Module):
    """Sequence-level classifier (sentiment, topic, ...) on top of a BERT encoder.

    Args:
        bert_model: Encoder called as ``bert_model(input_ids, segment_ids,
            attention_mask)`` and returning ``(sequence_output, pooled_output)``.
        num_labels: Number of output classes.
        hidden_size: Width of the encoder's pooled output.
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(self, bert_model, num_labels, hidden_size=768, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """Return class logits computed from the encoder's pooled output."""
        # Only the pooled (second) encoder output is used for classification.
        pooled = self.bert(input_ids, segment_ids, attention_mask)[1]
        return self.classifier(self.dropout(pooled))

class BERTForTokenClassification(nn.Module):
    """Per-token classifier (NER, POS tagging, ...) on top of a BERT encoder.

    Args:
        bert_model: Encoder called as ``bert_model(input_ids, segment_ids,
            attention_mask)`` and returning ``(sequence_output, pooled_output)``.
        num_labels: Number of tags predicted for each token.
        hidden_size: Width of the encoder's per-token states.
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(self, bert_model, num_labels, hidden_size=768, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """Return one row of tag logits per input token."""
        # Only the per-token (first) encoder output is needed here.
        token_states = self.bert(input_ids, segment_ids, attention_mask)[0]
        return self.classifier(self.dropout(token_states))

class BERTForQuestionAnswering(nn.Module):
    """Extractive QA head (SQuAD-style) on top of a BERT encoder.

    Args:
        bert_model: Encoder called as ``bert_model(input_ids, segment_ids,
            attention_mask)`` and returning ``(sequence_output, pooled_output)``.
        hidden_size: Width of the encoder's per-token states.
    """

    def __init__(self, bert_model, hidden_size=768):
        super().__init__()
        self.bert = bert_model

        # One linear layer yields two scores per token: answer start and answer end.
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """Return ``(start_logits, end_logits)``, each with one score per token."""
        token_states = self.bert(input_ids, segment_ids, attention_mask)[0]

        # The last dimension holds the (start, end) pair; unbinding it gives
        # two tensors with that dimension removed.
        span_scores = self.qa_outputs(token_states)
        start_logits, end_logits = span_scores.unbind(dim=-1)

        return start_logits, end_logits

# Fine-tuning example
def finetune_bert_classification():
    """Fine-tune BERT for sentiment analysis on a four-sentence toy batch.

    Runs a single optimization step on a pre-trained ``bert-base-uncased``
    classifier, prints the training loss, then classifies one test sentence.
    """
    from transformers import BertForSequenceClassification, BertTokenizer

    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2  # binary classification
    )
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Example training data
    texts = [
        "This movie was absolutely fantastic!",
        "I hated every minute of this film.",
        "One of the best movies I've ever seen.",
        "Terrible acting and boring plot."
    ]
    labels = [1, 0, 1, 0]  # 1=positive, 0=negative

    # Tokenize
    encodings = tokenizer(texts, padding=True, truncation=True,
                         return_tensors='pt')

    # torch.optim.AdamW replaces the deprecated transformers.AdamW.
    # weight_decay=0.0 keeps the old optimizer's default (PyTorch's own
    # default is 0.01).
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)

    # Training step
    model.train()
    optimizer.zero_grad()
    outputs = model(**encodings, labels=torch.tensor(labels))

    loss = outputs.loss
    loss.backward()
    optimizer.step()

    print(f"Training loss: {loss.item():.4f}")

    # Inference
    model.eval()
    test_text = "This is an amazing product!"
    test_encoding = tokenizer(test_text, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**test_encoding)
        prediction = torch.argmax(outputs.logits, dim=1)

    sentiment = "Positive" if prediction.item() == 1 else "Negative"
    print(f"\nTest: '{test_text}'")
    print(f"Predicted sentiment: {sentiment}")

finetune_bert_classification()

Transfer Learning Power: BERT's pre-trained representations enable strong performance even with limited task-specific data. Fine-tuning often requires only hundreds of labeled examples rather than thousands.

BERT Variants and Improvements

python
"""
BERT Variants Timeline:

1. BERT (2018) - Original bidirectional model
2. RoBERTa (2019) - Robustly Optimized BERT
   - Removes NSP
   - Dynamic masking
   - Larger batches, more data
   - Better performance

3. ALBERT (2019) - A Lite BERT
   - Factorized embeddings
   - Cross-layer parameter sharing
   - Sentence-order prediction (SOP)
   - Fewer parameters, similar performance

4. DistilBERT (2019) - Distilled version
   - 40% smaller, 60% faster
   - Retains 97% of BERT's performance
   - Knowledge distillation

5. ELECTRA (2020) - Replaced token detection
   - More efficient pre-training
   - Generator-discriminator setup
   - Better sample efficiency
"""

# RoBERTa improvements: a plain-text summary of how RoBERTa's training recipe
# differs from BERT's. The string is only printed below; nothing parses it.
roberta_improvements = """
RoBERTa (Robustly Optimized BERT Approach):

1. Dynamic Masking:
   - BERT: Static masks during pre-processing
   - RoBERTa: Generate masks dynamically during training
   - Same data seen with different masks = more training signal

2. Remove NSP:
   - Train only with MLM objective
   - Use full sentences without sentence pairs
   - Improves performance on downstream tasks

3. Larger Batches:
   - BERT: 256 sequences
   - RoBERTa: 8,000 sequences
   - Better gradient estimates

4. More Data:
   - BERT: 16GB (BookCorpus + Wikipedia)
   - RoBERTa: 160GB (CC-News, OpenWebText, Stories)

5. Byte-Pair Encoding:
   - BERT: WordPiece with 30K vocab
   - RoBERTa: BPE with 50K vocab
   - Better handling of rare words
"""

print(roberta_improvements)

# ALBERT optimizations
class ALBERTFactorizedEmbedding(nn.Module):
    """Token embedding split into a small lookup table plus a linear projection.

    A single ``vocab_size x hidden_size`` table is replaced by a
    ``vocab_size x embedding_size`` table followed by an
    ``embedding_size -> hidden_size`` linear layer.

    Args:
        vocab_size: Number of rows in the lookup table.
        embedding_size: Width of the lookup table.
        hidden_size: Width of the projected output.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size):
        super().__init__()

        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.projection = nn.Linear(embedding_size, hidden_size)

    def forward(self, input_ids):
        """Look up the low-dimensional embeddings and project them to ``hidden_size``."""
        return self.projection(self.word_embeddings(input_ids))

# Parameter comparison
def compare_bert_variants():
    """Print parameter count, depth, width, and a short note for five BERT-family models."""

    variants = {
        'BERT-Base': {
            'params': '110M',
            'layers': 12,
            'hidden': 768,
            'note': 'Original'
        },
        'RoBERTa-Base': {
            'params': '125M',
            'layers': 12,
            'hidden': 768,
            'note': 'Better training'
        },
        'ALBERT-Base': {
            'params': '12M',
            'layers': 12,
            'hidden': 768,
            'note': 'Parameter sharing'
        },
        'DistilBERT': {
            'params': '66M',
            'layers': 6,
            'hidden': 768,
            'note': 'Distilled, faster'
        },
        'ELECTRA-Base': {
            'params': '110M',
            'layers': 12,
            'hidden': 768,
            'note': 'Efficient pre-training'
        }
    }

    print("BERT Variants Comparison:\n")
    for name in variants:
        specs = variants[name]
        # One multi-line summary per model; print() supplies the final newline,
        # which leaves a blank line after each entry.
        summary = (
            f"{name}:\n"
            f"  Parameters: {specs['params']}\n"
            f"  Layers: {specs['layers']}\n"
            f"  Hidden size: {specs['hidden']}\n"
            f"  Note: {specs['note']}\n"
        )
        print(summary)

compare_bert_variants()

Practice Exercise

python
# Exercise: Implement attention visualization for BERT
def visualize_bert_attention():
    """
    Visualize BERT's attention patterns to understand
    how it attends to different parts of the input

    Loads ``bert-base-uncased`` with attention outputs enabled, averages the
    last layer's attention over all heads for one sentence, and saves the
    resulting heat map to ``bert_attention.png`` in the working directory.
    """
    from transformers import BertModel, BertTokenizer
    import torch
    import matplotlib.pyplot as plt

    model = BertModel.from_pretrained('bert-base-uncased',
                                     output_attentions=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    text = "The cat sat on the mat."
    inputs = tokenizer(text, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)

    # Get attention weights from last layer
    attention = outputs.attentions[-1]  # [batch, heads, seq, seq]
    # One attention tensor is returned per layer, so the count is the
    # 1-based index of the last layer (used in the plot title).
    last_layer = len(outputs.attentions)

    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Average across all attention heads
    avg_attention = attention[0].mean(dim=0).numpy()

    # Plot attention matrix
    fig = plt.figure(figsize=(10, 8))
    plt.imshow(avg_attention, cmap='viridis')
    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.yticks(range(len(tokens)), tokens)
    plt.xlabel('Key (attending to)')
    plt.ylabel('Query (attending from)')
    plt.title(
        f'BERT Attention Visualization (Layer {last_layer}, Average over Heads)'
    )
    plt.colorbar()
    plt.tight_layout()
    plt.savefig('bert_attention.png', dpi=150)
    # Release the figure once it is on disk so repeated calls do not
    # accumulate open figures.
    plt.close(fig)

    print(f"Tokens: {tokens}")
    print(f"Attention shape: {avg_attention.shape}")
    print("Attention visualization saved")

visualize_bert_attention()

# Exercise questions: free-form practice prompts for the reader. The string is
# only printed below; nothing in the code depends on its contents.
exercise_questions = """
Practice Exercises:

1. Why does BERT use learned positional embeddings instead of sinusoidal?

2. Implement dynamic masking: Create 3 different masked versions of the
   same sentence. How does this provide more training signal?

3. Compare BERT's [CLS] token approach to averaging all token embeddings
   for sentence representation. Which is better and why?

4. Calculate: How many parameters does BERT-Base save by using
   hidden_size=768 instead of 1024?

5. Design: Create a BERT-based model for detecting sarcasm.
   What architecture modifications would you make?
"""

print(exercise_questions)

Quiz

Further Reading