BERT and Bidirectional Models
BERT (Bidirectional Encoder Representations from Transformers) revolutionized NLP by introducing deep bidirectional pre-training. Unlike GPT's left-to-right approach, BERT attends to context on both sides of every token at once, enabling richer contextual understanding.
The Bidirectional Revolution
Bidirectional Context: The ability to use information from both before and after a token when building its representation, as opposed to unidirectional models that only see previous context. This enables richer understanding for tasks where the full input is available.
Why Bidirectional Matters
"""
Unidirectional (GPT) vs Bidirectional (BERT) context:
Sentence: "The bank of the river was muddy."
GPT (left-to-right):
- Processing "bank": only sees "The"
- Cannot use "river" to disambiguate meaning
BERT (bidirectional):
- Processing "bank": sees both "The" and "of the river was muddy"
- Can correctly understand "bank" as riverbank, not financial institution
"""
import torch
import torch.nn as nn
def demonstrate_context_importance():
"""Show how bidirectional context helps disambiguation"""
examples = [
{
"sentence": "The bank can guarantee deposits.",
"word": "bank",
"left_context": "The",
"right_context": "can guarantee deposits",
"meaning": "financial institution"
},
{
"sentence": "The bank was full of flowers.",
"word": "bank",
"left_context": "The",
"right_context": "was full of flowers",
"meaning": "riverbank/slope"
}
]
print("Bidirectional Context Disambiguation:\n")
for ex in examples:
print(f"Sentence: {ex['sentence']}")
print(f"Word: '{ex['word']}'")
print(f"Left-only context: '{ex['left_context']}' → ambiguous")
print(f"With right context: '{ex['right_context']}'")
print(f"Correct meaning: {ex['meaning']}\n")
demonstrate_context_importance()
Key Insight: Bidirectional models can use future context to better understand the present, making them superior for understanding tasks but unsuitable for generation (which requires autoregressive left-to-right processing).
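The contrast between the two regimes comes down to the attention mask. The sketch below builds both masks with illustrative tensors only (not tied to any particular model): a GPT-style decoder blocks the strictly upper triangle so each token sees only its past, while BERT's encoder blocks nothing.

```python
import torch

seq_len = 5

# GPT-style causal mask: True marks a blocked position, so token i can
# only attend to positions 0..i (the strictly upper triangle is blocked)
causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
)

# BERT-style bidirectional "mask": nothing is blocked, so every token
# attends to the full sequence in both directions
full_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

print(causal_mask.int())
print("Blocked positions (causal):", causal_mask.sum().item())        # 10
print("Blocked positions (bidirectional):", full_mask.sum().item())  # 0
```

For `seq_len=5` the causal mask blocks 10 of the 25 query-key pairs; the bidirectional mask blocks none, which is exactly why BERT can use "river" when representing "bank".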
BERT Architecture
Model Design
"""
BERT Model Specifications:
BERT-Base:
- Parameters: 110M
- Layers: 12 encoder blocks
- Hidden size: 768
- Attention heads: 12
- Max sequence length: 512
- Vocab size: 30,000 (WordPiece)
BERT-Large:
- Parameters: 340M
- Layers: 24 encoder blocks
- Hidden size: 1024
- Attention heads: 16
- Max sequence length: 512
"""
class BERTEmbedding(nn.Module):
"""BERT's three-part embedding layer"""
def __init__(self, vocab_size, hidden_size, max_len=512,
num_segments=2, dropout=0.1):
super().__init__()
# Token embeddings (WordPiece)
self.token_embed = nn.Embedding(vocab_size, hidden_size)
# Position embeddings (learned, not sinusoidal)
self.position_embed = nn.Embedding(max_len, hidden_size)
# Segment embeddings (for sentence pairs)
self.segment_embed = nn.Embedding(num_segments, hidden_size)
self.norm = nn.LayerNorm(hidden_size)
self.dropout = nn.Dropout(dropout)
def forward(self, token_ids, segment_ids=None):
batch_size, seq_len = token_ids.shape
# Generate position IDs
position_ids = torch.arange(seq_len, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Default segment IDs (all zeros)
if segment_ids is None:
segment_ids = torch.zeros_like(token_ids)
# Sum all three embeddings
embeddings = (
self.token_embed(token_ids) +
self.position_embed(position_ids) +
self.segment_embed(segment_ids)
)
return self.dropout(self.norm(embeddings))
class BERTEncoderBlock(nn.Module):
"""BERT transformer encoder block"""
def __init__(self, hidden_size, num_heads, ff_size=None, dropout=0.1):
super().__init__()
ff_size = ff_size or 4 * hidden_size
# Multi-head self-attention
self.attention = nn.MultiheadAttention(
hidden_size, num_heads, dropout=dropout, batch_first=True
)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(hidden_size, ff_size),
nn.GELU(), # BERT uses GELU activation
nn.Linear(ff_size, hidden_size),
nn.Dropout(dropout)
)
# Layer normalization
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
self.dropout = nn.Dropout(dropout)
    def forward(self, x, key_padding_mask=None):
        # Self-attention with residual. nn.MultiheadAttention expects a
        # [batch, seq] boolean key_padding_mask (True = padding position);
        # an HF-style 1/0 attention_mask must not be passed as attn_mask,
        # which expects a [seq, seq] attention-pattern mask instead.
        attn_out, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
# Feed-forward with residual
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
class BERTModel(nn.Module):
"""Complete BERT model architecture"""
def __init__(self, vocab_size=30000, hidden_size=768, num_layers=12,
num_heads=12, max_len=512):
super().__init__()
# Embeddings
self.embeddings = BERTEmbedding(vocab_size, hidden_size, max_len)
# Transformer encoder blocks
self.encoder_blocks = nn.ModuleList([
BERTEncoderBlock(hidden_size, num_heads)
for _ in range(num_layers)
])
self.pooler = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.Tanh()
)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
# Embeddings
x = self.embeddings(input_ids, segment_ids)
# Encoder blocks
for block in self.encoder_blocks:
x = block(x, attention_mask)
# Pooled output (first token [CLS])
pooled_output = self.pooler(x[:, 0])
return x, pooled_output
# Model size calculation
model = BERTModel(num_layers=12, hidden_size=768)
total_params = sum(p.numel() for p in model.parameters())
print(f"BERT-Base parameters: {total_params:,}") # ~110M
[CLS] Token: BERT adds a special [CLS] token at the beginning of every sequence. The final hidden state of this token is used as the aggregate sequence representation for classification tasks.
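In code terms, the pooled representation is just the final hidden state at position 0 passed through the pooler. A shape-level sketch, with a random tensor standing in for real encoder output:

```python
import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 8, 768
# Stand-in for the encoder's final hidden states [batch, seq, hidden]
sequence_output = torch.randn(batch, seq_len, hidden)

# [CLS] is always the first token, so its vector sits at position 0
cls_vector = sequence_output[:, 0]  # [batch, hidden]

# BERT's pooler: dense layer + tanh over the [CLS] vector
pooler = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
pooled_output = pooler(cls_vector)  # [batch, hidden]

print(cls_vector.shape, pooled_output.shape)
```

The tanh bounds every coordinate of the pooled output to [-1, 1]; classification heads are attached on top of this vector.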
Masked Language Modeling (MLM)
Masked Language Modeling (MLM): A pre-training objective where random tokens in the input are masked (hidden), and the model learns to predict the original tokens based on bidirectional context, enabling deep bidirectional representations.
The core pre-training task that enables bidirectional learning.
MLM Implementation
"""
Masked Language Modeling Strategy:
1. Randomly select 15% of tokens
2. Of selected tokens:
- 80% replace with [MASK]
- 10% replace with random token
- 10% keep unchanged
This prevents the model from only learning about [MASK] tokens.
"""
class MaskedLanguageModel(nn.Module):
"""MLM prediction head for BERT pre-training"""
def __init__(self, hidden_size, vocab_size):
super().__init__()
self.transform = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.GELU(),
nn.LayerNorm(hidden_size)
)
self.decoder = nn.Linear(hidden_size, vocab_size)
def forward(self, hidden_states):
# Transform hidden states
hidden_states = self.transform(hidden_states)
# Predict tokens
predictions = self.decoder(hidden_states)
return predictions
class MLMTrainer:
"""Handles MLM training logic"""
def __init__(self, vocab_size, mask_token_id=103, mask_prob=0.15):
self.vocab_size = vocab_size
self.mask_token_id = mask_token_id
self.mask_prob = mask_prob
def create_mlm_batch(self, input_ids):
"""
Create masked inputs and labels for MLM
Args:
input_ids: Original token IDs [batch_size, seq_len]
Returns:
masked_input_ids: Input with some tokens masked
labels: Original tokens at masked positions
"""
batch_size, seq_len = input_ids.shape
# Clone input
masked_input = input_ids.clone()
labels = input_ids.clone()
# Create mask for 15% of tokens (excluding special tokens)
probability_matrix = torch.full(input_ids.shape, self.mask_prob)
# Don't mask special tokens [CLS]=101, [SEP]=102, [PAD]=0
special_tokens_mask = (
(input_ids == 101) | (input_ids == 102) | (input_ids == 0)
)
probability_matrix.masked_fill_(special_tokens_mask, 0.0)
# Select tokens to mask
masked_indices = torch.bernoulli(probability_matrix).bool()
# Only compute loss on masked tokens
labels[~masked_indices] = -100
# 80% of time: replace with [MASK]
indices_replaced = (
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() &
masked_indices
)
masked_input[indices_replaced] = self.mask_token_id
# 10% of time: replace with random token
indices_random = (
torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() &
masked_indices &
~indices_replaced
)
random_tokens = torch.randint(
self.vocab_size, input_ids.shape, dtype=torch.long
)
masked_input[indices_random] = random_tokens[indices_random]
# 10% of time: keep original token (do nothing)
return masked_input, labels
def demonstrate_mlm():
"""Demonstrate MLM masking"""
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
sentence = "The quick brown fox jumps over the lazy dog."
tokens = tokenizer.tokenize(sentence)
print("Original sentence:", sentence)
print("Tokens:", tokens, "\n")
# Simulate masking
masked_version = tokens.copy()
    masked_positions = [2, 4, 7]  # brown, jumps, lazy
for pos in masked_positions:
masked_version[pos] = '[MASK]'
print("Masked version:", ' '.join(masked_version))
print("\nMLM task: Predict the original tokens at [MASK] positions")
print("Model sees context from both directions to predict:")
for pos in masked_positions:
left_context = ' '.join(tokens[:pos])
right_context = ' '.join(tokens[pos+1:])
print(f"\n Position {pos}: '{tokens[pos]}'")
print(f" Left context: ...{left_context}")
print(f" Right context: {right_context}...")
demonstrate_mlm()
Training MLM
# Complete MLM training loop
def train_mlm_epoch(model, mlm_head, dataloader, optimizer, device):
"""Train BERT with MLM objective"""
model.train()
mlm_head.train()
total_loss = 0
criterion = nn.CrossEntropyLoss(ignore_index=-100)
mlm_trainer = MLMTrainer(vocab_size=30000)
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
# Create masked inputs
masked_input, labels = mlm_trainer.create_mlm_batch(input_ids)
masked_input = masked_input.to(device)
labels = labels.to(device)
# Forward pass
hidden_states, _ = model(masked_input)
predictions = mlm_head(hidden_states)
# Compute loss only on masked tokens
loss = criterion(
predictions.view(-1, predictions.size(-1)),
labels.view(-1)
)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# Example: Predict masked words
def predict_masked_tokens():
"""Use pre-trained BERT to predict masked words"""
from transformers import BertForMaskedLM, BertTokenizer
import torch
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Example sentence with mask
text = "The capital of France is [MASK]."
inputs = tokenizer(text, return_tensors='pt')
# Get predictions
with torch.no_grad():
outputs = model(**inputs)
predictions = outputs.logits
# Get the predicted token for [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
    predicted_token_id = predictions[0, mask_token_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(f"Input: {text}")
print(f"Predicted word: {predicted_token}")
# Get top 5 predictions
top5_tokens = predictions[0, mask_token_index].topk(5).indices[0].tolist()
top5_words = [tokenizer.decode([token]) for token in top5_tokens]
print(f"Top 5 predictions: {top5_words}")
predict_masked_tokens()
Why 80-10-10? Using [MASK] 100% of the time would create a mismatch between pre-training and fine-tuning (where [MASK] never appears). The 80-10-10 split helps the model learn robust representations.
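The split is easy to verify empirically. This sketch applies the same Bernoulli logic as `create_mlm_batch` to a long random sequence (no special tokens, for simplicity) and measures the resulting fractions; expected values are ~12% [MASK], ~1.5% random, ~1.5% unchanged, for 15% selected overall.

```python
import torch

torch.manual_seed(0)
n = 100_000
ids = torch.randint(1000, 2000, (1, n))

# Same three-stage sampling as the MLMTrainer above
selected = torch.bernoulli(torch.full(ids.shape, 0.15)).bool()
masked = torch.bernoulli(torch.full(ids.shape, 0.8)).bool() & selected
random_ = torch.bernoulli(torch.full(ids.shape, 0.5)).bool() & selected & ~masked
kept = selected & ~masked & ~random_

for name, m in [("selected", selected), ("[MASK]", masked),
                ("random", random_), ("kept", kept)]:
    print(f"{name:>8}: {m.float().mean().item():.4f}")
```

Note the 0.5 in the second draw: it splits the remaining 20% of selected tokens evenly, which is how 10%/10% falls out of 0.15 × 0.2 × 0.5.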
Next Sentence Prediction (NSP)
Next Sentence Prediction (NSP): A binary classification pre-training task where the model predicts whether two sentences appeared consecutively in the original text or were randomly paired, intended to help learn sentence-level relationships.
BERT's second pre-training objective for understanding sentence relationships.
NSP Implementation
"""
Next Sentence Prediction (NSP):
Task: Given two sentences A and B, predict if B actually follows A
in the original document.
Input format:
[CLS] Sentence A [SEP] Sentence B [SEP]
Labels:
- IsNext (1): B actually follows A
- NotNext (0): B is random sentence from corpus
"""
class NextSentencePrediction(nn.Module):
"""NSP prediction head"""
def __init__(self, hidden_size):
super().__init__()
self.classifier = nn.Linear(hidden_size, 2)
def forward(self, pooled_output):
"""
Args:
pooled_output: [CLS] token representation [batch_size, hidden_size]
Returns:
logits: [batch_size, 2] (NotNext, IsNext)
"""
return self.classifier(pooled_output)
class NSPDataset:
"""Create NSP training data"""
def __init__(self, documents):
self.documents = documents
def create_nsp_pair(self):
"""
Create one NSP training example
Returns:
sentence_a: First sentence
sentence_b: Second sentence
label: 1 if B follows A, 0 otherwise
"""
import random
# Select random document
doc = random.choice(self.documents)
        # Drop empty fragments (e.g. after a trailing period)
        sentences = [s for s in doc.split('.') if s.strip()]
        if len(sentences) < 2:
            return self.create_nsp_pair()  # Try again with another document
# Select sentence A
idx = random.randint(0, len(sentences) - 2)
sentence_a = sentences[idx].strip()
# 50% of time: next sentence (positive)
if random.random() < 0.5:
sentence_b = sentences[idx + 1].strip()
label = 1
# 50% of time: random sentence (negative)
else:
random_doc = random.choice(self.documents)
random_sentences = random_doc.split('.')
sentence_b = random.choice(random_sentences).strip()
label = 0
return sentence_a, sentence_b, label
def demonstrate_nsp():
"""Demonstrate NSP task"""
from transformers import BertForNextSentencePrediction, BertTokenizer
import torch
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Positive example (coherent)
prompt = "The sun rises in the east."
next_sentence = "It sets in the west."
encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
with torch.no_grad():
outputs = model(**encoding)
logits = outputs.logits
probs = torch.softmax(logits, dim=1)
is_next_prob = probs[0, 1].item()
print("Example 1 (Coherent):")
print(f" Sentence A: {prompt}")
print(f" Sentence B: {next_sentence}")
print(f" IsNext probability: {is_next_prob:.3f}\n")
# Negative example (incoherent)
prompt = "The sun rises in the east."
random_sentence = "I love eating pizza on Fridays."
encoding = tokenizer(prompt, random_sentence, return_tensors='pt')
with torch.no_grad():
outputs = model(**encoding)
logits = outputs.logits
probs = torch.softmax(logits, dim=1)
is_next_prob = probs[0, 1].item()
print("Example 2 (Incoherent):")
print(f" Sentence A: {prompt}")
print(f" Sentence B: {random_sentence}")
print(f" IsNext probability: {is_next_prob:.3f}")
demonstrate_nsp()
NSP Controversy: Later research (RoBERTa) showed that NSP might not be necessary and could even hurt performance. Modern BERT variants often remove or modify this objective.
Combined Pre-training
Joint MLM + NSP Training
class BERTPreTraining(nn.Module):
"""Complete BERT pre-training with MLM + NSP"""
def __init__(self, bert_model, vocab_size, hidden_size):
super().__init__()
self.bert = bert_model
# MLM head
self.mlm_head = MaskedLanguageModel(hidden_size, vocab_size)
# NSP head
self.nsp_head = NextSentencePrediction(hidden_size)
def forward(self, input_ids, segment_ids, attention_mask=None):
# Get BERT outputs
sequence_output, pooled_output = self.bert(
input_ids, segment_ids, attention_mask
)
# MLM predictions
mlm_predictions = self.mlm_head(sequence_output)
# NSP predictions
nsp_predictions = self.nsp_head(pooled_output)
return mlm_predictions, nsp_predictions
def train_bert_pretraining(model, dataloader, optimizer, device):
"""Combined MLM + NSP training"""
model.train()
mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
nsp_criterion = nn.CrossEntropyLoss()
total_mlm_loss = 0
total_nsp_loss = 0
for batch in dataloader:
# Move to device
input_ids = batch['input_ids'].to(device)
segment_ids = batch['segment_ids'].to(device)
mlm_labels = batch['mlm_labels'].to(device)
nsp_labels = batch['nsp_labels'].to(device)
# Forward pass
mlm_preds, nsp_preds = model(input_ids, segment_ids)
# Compute losses
mlm_loss = mlm_criterion(
mlm_preds.view(-1, mlm_preds.size(-1)),
mlm_labels.view(-1)
)
nsp_loss = nsp_criterion(nsp_preds, nsp_labels)
# Combined loss
total_loss = mlm_loss + nsp_loss
# Backward pass
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
total_mlm_loss += mlm_loss.item()
total_nsp_loss += nsp_loss.item()
return {
'mlm_loss': total_mlm_loss / len(dataloader),
'nsp_loss': total_nsp_loss / len(dataloader)
}
Fine-tuning BERT
Task-Specific Adaptations
"""
BERT Fine-tuning for Different Tasks:
1. Single Sentence Classification (sentiment, topic):
[CLS] sentence [SEP] → Use [CLS] representation
2. Sentence Pair Classification (entailment, similarity):
[CLS] sentence_a [SEP] sentence_b [SEP] → Use [CLS] representation
3. Question Answering (SQuAD):
[CLS] question [SEP] passage [SEP] → Predict start/end spans
4. Token Classification (NER, POS tagging):
Use each token's representation
"""
class BERTForSequenceClassification(nn.Module):
"""BERT for sentiment analysis, topic classification, etc."""
def __init__(self, bert_model, num_labels, hidden_size=768, dropout=0.1):
super().__init__()
self.bert = bert_model
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(hidden_size, num_labels)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
# Get BERT outputs
_, pooled_output = self.bert(input_ids, segment_ids, attention_mask)
# Classification
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
class BERTForTokenClassification(nn.Module):
"""BERT for NER, POS tagging, etc."""
def __init__(self, bert_model, num_labels, hidden_size=768, dropout=0.1):
super().__init__()
self.bert = bert_model
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(hidden_size, num_labels)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
# Get BERT outputs (all tokens)
sequence_output, _ = self.bert(input_ids, segment_ids, attention_mask)
# Token-level classification
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
return logits
class BERTForQuestionAnswering(nn.Module):
"""BERT for extractive QA (SQuAD-style)"""
def __init__(self, bert_model, hidden_size=768):
super().__init__()
self.bert = bert_model
# Predict start and end positions
self.qa_outputs = nn.Linear(hidden_size, 2)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
# Get BERT outputs
sequence_output, _ = self.bert(input_ids, segment_ids, attention_mask)
# Predict start and end logits
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
return start_logits, end_logits
# Fine-tuning example
def finetune_bert_classification():
"""Fine-tune BERT for sentiment analysis"""
    from transformers import BertForSequenceClassification, BertTokenizer
    from torch.optim import AdamW  # transformers' own AdamW is deprecated
model = BertForSequenceClassification.from_pretrained(
'bert-base-uncased',
num_labels=2 # binary classification
)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Example training data
texts = [
"This movie was absolutely fantastic!",
"I hated every minute of this film.",
"One of the best movies I've ever seen.",
"Terrible acting and boring plot."
]
labels = [1, 0, 1, 0] # 1=positive, 0=negative
# Tokenize
encodings = tokenizer(texts, padding=True, truncation=True,
return_tensors='pt')
# Training step
optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
outputs = model(**encodings, labels=torch.tensor(labels))
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"Training loss: {loss.item():.4f}")
# Inference
model.eval()
test_text = "This is an amazing product!"
test_encoding = tokenizer(test_text, return_tensors='pt')
with torch.no_grad():
outputs = model(**test_encoding)
prediction = torch.argmax(outputs.logits, dim=1)
sentiment = "Positive" if prediction.item() == 1 else "Negative"
print(f"\nTest: '{test_text}'")
print(f"Predicted sentiment: {sentiment}")
finetune_bert_classification()
Transfer Learning Power: BERT's pre-trained representations enable strong performance even with limited task-specific data. Fine-tuning often requires only hundreds of labeled examples rather than thousands.
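One practical corollary: with very little labeled data, it can help to freeze the pre-trained encoder (or just its lower layers) and train only the new head. A generic sketch with a stand-in encoder; the freezing pattern applies unchanged to the BERT modules above.

```python
import torch.nn as nn

# Stand-in for a pre-trained encoder; any nn.Module works the same way
encoder = nn.Sequential(nn.Embedding(30000, 768), nn.Linear(768, 768))
classifier = nn.Linear(768, 2)  # new task-specific head

# Freeze the pre-trained weights so only the head receives gradients
for p in encoder.parameters():
    p.requires_grad = False

trainable = [p for p in list(encoder.parameters()) + list(classifier.parameters())
             if p.requires_grad]
print(f"Trainable tensors: {len(trainable)}")  # the classifier's weight and bias
```

Pass only the trainable parameters to the optimizer (e.g. `torch.optim.AdamW(trainable, lr=2e-5)`); frozen parameters then cost no optimizer state and no gradient computation.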
BERT Variants and Improvements
"""
BERT Variants Timeline:
1. BERT (2018) - Original bidirectional model
2. RoBERTa (2019) - Robustly Optimized BERT
- Removes NSP
- Dynamic masking
- Larger batches, more data
- Better performance
3. ALBERT (2019) - A Lite BERT
- Factorized embeddings
- Cross-layer parameter sharing
- Sentence-order prediction (SOP)
- Fewer parameters, similar performance
4. DistilBERT (2019) - Distilled version
- 40% smaller, 60% faster
- Retains 97% of BERT's performance
- Knowledge distillation
5. ELECTRA (2020) - Replaced token detection
- More efficient pre-training
- Generator-discriminator setup
- Better sample efficiency
"""
# RoBERTa improvements
roberta_improvements = """
RoBERTa (Robustly Optimized BERT Approach):
1. Dynamic Masking:
- BERT: Static masks during pre-processing
- RoBERTa: Generate masks dynamically during training
- Same data seen with different masks = more training signal
2. Remove NSP:
- Train only with MLM objective
- Use full sentences without sentence pairs
- Improves performance on downstream tasks
3. Larger Batches:
- BERT: 256 sequences
- RoBERTa: 8,000 sequences
- Better gradient estimates
4. More Data:
- BERT: 16GB (BookCorpus + Wikipedia)
- RoBERTa: 160GB (CC-News, OpenWebText, Stories)
5. Byte-Pair Encoding:
- BERT: WordPiece with 30K vocab
- RoBERTa: BPE with 50K vocab
- Better handling of rare words
"""
print(roberta_improvements)
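Dynamic masking is the one RoBERTa change that is trivial to reproduce: draw a fresh mask every time a sequence is seen, instead of fixing masks once during preprocessing. A sketch:

```python
import torch

torch.manual_seed(0)
input_ids = torch.randint(1000, 2000, (1, 200))  # one 200-token sequence

def draw_mask(ids, mask_prob=0.15):
    """Fresh Bernoulli mask per call -- RoBERTa-style dynamic masking."""
    return torch.bernoulli(torch.full(ids.shape, mask_prob)).bool()

# The same sequence gets different masks on different "epochs"
mask_a, mask_b = draw_mask(input_ids), draw_mask(input_ids)
print("Identical masks:", torch.equal(mask_a, mask_b))
print("Positions masked:", mask_a.sum().item(), "vs", mask_b.sum().item())
```

Each pass over the corpus therefore asks the model to predict a different subset of tokens, which is the "more training signal" claimed above. The `MLMTrainer` earlier in this section already behaves this way, since it samples masks inside `create_mlm_batch`.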
# ALBERT optimizations
class ALBERTFactorizedEmbedding(nn.Module):
"""ALBERT's factorized embedding parameterization"""
def __init__(self, vocab_size, embedding_size, hidden_size):
super().__init__()
# Instead of vocab_size × hidden_size
# Use vocab_size × embedding_size + embedding_size × hidden_size
self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
self.projection = nn.Linear(embedding_size, hidden_size)
def forward(self, input_ids):
embeddings = self.word_embeddings(input_ids)
projected = self.projection(embeddings)
return projected
# Parameter comparison
def compare_bert_variants():
"""Compare parameter counts"""
variants = {
'BERT-Base': {
'params': '110M',
'layers': 12,
'hidden': 768,
'note': 'Original'
},
'RoBERTa-Base': {
'params': '125M',
'layers': 12,
'hidden': 768,
'note': 'Better training'
},
'ALBERT-Base': {
'params': '12M',
'layers': 12,
'hidden': 768,
'note': 'Parameter sharing'
},
'DistilBERT': {
'params': '66M',
'layers': 6,
'hidden': 768,
'note': 'Distilled, faster'
},
'ELECTRA-Base': {
'params': '110M',
'layers': 12,
'hidden': 768,
'note': 'Efficient pre-training'
}
}
print("BERT Variants Comparison:\n")
for name, specs in variants.items():
print(f"{name}:")
print(f" Parameters: {specs['params']}")
print(f" Layers: {specs['layers']}")
print(f" Hidden size: {specs['hidden']}")
print(f" Note: {specs['note']}\n")
compare_bert_variants()
Practice Exercise
# Exercise: Implement attention visualization for BERT
def visualize_bert_attention():
"""
Visualize BERT's attention patterns to understand
how it attends to different parts of the input
"""
from transformers import BertModel, BertTokenizer
import torch
import matplotlib.pyplot as plt
import numpy as np
model = BertModel.from_pretrained('bert-base-uncased',
output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "The cat sat on the mat."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
# Get attention weights from last layer
attention = outputs.attentions[-1] # [batch, heads, seq, seq]
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# Average across all attention heads
avg_attention = attention[0].mean(dim=0).numpy()
# Plot attention matrix
plt.figure(figsize=(10, 8))
plt.imshow(avg_attention, cmap='viridis')
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.xlabel('Key (attending to)')
plt.ylabel('Query (attending from)')
plt.title('BERT Attention Visualization (Layer 12, Average over Heads)')
plt.colorbar()
plt.tight_layout()
plt.savefig('bert_attention.png', dpi=150)
print(f"Tokens: {tokens}")
print(f"Attention shape: {avg_attention.shape}")
print("Attention visualization saved")
visualize_bert_attention()
# Exercise questions
exercise_questions = """
Practice Exercises:
1. Why does BERT use learned positional embeddings instead of sinusoidal?
2. Implement dynamic masking: Create 3 different masked versions of the
same sentence. How does this provide more training signal?
3. Compare BERT's [CLS] token approach to averaging all token embeddings
for sentence representation. Which is better and why?
4. Calculate: How many parameters does BERT-Base save by using
hidden_size=768 instead of 1024?
5. Design: Create a BERT-based model for detecting sarcasm.
What architecture modifications would you make?
"""
print(exercise_questions)
Further Reading
- BERT: Pre-training of Deep Bidirectional Transformers (Original Paper)
- RoBERTa: A Robustly Optimized BERT Pretraining Approach
- ALBERT: A Lite BERT for Self-supervised Learning
- DistilBERT: A distilled version of BERT
- ELECTRA: Pre-training Text Encoders as Discriminators
- The Illustrated BERT (Jay Alammar)