Lecture 23: Implementing GPT from scratch

PSYC 51.17: Models of language and communication

Jeremy R. Manning
Dartmouth College
Winter 2026

Learning objectives

  1. Implement a complete GPT model from scratch in PyTorch (~30M parameters)
  2. Build each component: embeddings, masked attention, transformer block, LM head
  3. Write a training loop with AdamW, gradient clipping, and gradient accumulation
  4. Apply mixed precision training and learning rate scheduling for efficient training
  5. Implement text generation with temperature, top-k, and nucleus sampling

Announcements and roadmap

There are no classes February 23–27 (instructor away). Use this time to work on your final project and the optional Assignment 5 (Build GPT).

📓 Companion Notebook — build and train a mini-GPT step by step. All code from this lecture runs in the notebook.

What we are building today

We will implement every component of a GPT language model from scratch — small enough to train on a laptop, but architecturally identical to GPT-2:

  1. Tokenization with tiktoken (GPT-2's BPE tokenizer)
  2. Token + position embeddings
  3. Masked multi-head attention (causal mask)
  4. Transformer decoder blocks (pre-norm with residual connections)
  5. Language model head with weight tying
  6. Training loop with AdamW, gradient accumulation, and mixed precision
  7. Text generation with multiple sampling strategies

Setup and hyperparameters


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import tiktoken  # OpenAI's BPE tokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = {
    'vocab_size': 50257,    # GPT-2 vocabulary size
    'd_model': 384,         # Embedding dimension
    'n_layers': 6,          # Number of transformer blocks
    'n_heads': 6,           # Number of attention heads
    'max_seq_len': 256,     # Maximum sequence length
    'dropout': 0.1,
    'batch_size': 32,
    'learning_rate': 3e-4,
    'num_epochs': 10
}

Tokenization with tiktoken


tokenizer = tiktoken.get_encoding("gpt2")

text = "Hello, how are you doing today?"
tokens = tokenizer.encode(text)
# [15496, 11, 703, 389, 345, 1804, 1909, 30]

# Decode back to text
decoded = tokenizer.decode(tokens)  # "Hello, how are you doing today?"

# Inspect individual tokens
for tid in tokens:
    print(f"  {tid}: '{tokenizer.decode([tid])}'")

Creating a text dataset


class TextDataset(Dataset):
    def __init__(self, text_file, tokenizer, max_seq_len):
        with open(text_file, 'r', encoding='utf-8') as f:
            text = f.read()
        self.tokens = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.tokens) - self.max_seq_len

    def __getitem__(self, idx):
        chunk = self.tokens[idx : idx + self.max_seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)  # Input
        y = torch.tensor(chunk[1:],  dtype=torch.long)  # Target
        return x, y

dataset = TextDataset('shakespeare.txt', tokenizer, config['max_seq_len'])
dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

The target at each position is the next token: for input [A, B, C, D], the targets are [B, C, D, E], where E is the token that follows D in the corpus. This is the autoregressive training signal.
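
A quick sanity check makes the offset concrete. This sketch assumes the dataset built above and simply verifies that targets are inputs shifted by one:

x, y = dataset[0]
print(tokenizer.decode(x[:8].tolist()))   # first eight input tokens
print(tokenizer.decode(y[:8].tolist()))   # same passage, shifted left by one token
assert torch.equal(x[1:], y[:-1])         # every target is the next input token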

Token and position embeddings


class Embeddings(nn.Module):
    def __init__(self, vocab_size, d_model, max_seq_len, dropout):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        seq_len = x.size(1)
        tok_emb = self.token_embed(x)                        # (B, T, D)
        pos_emb = self.pos_embed(torch.arange(seq_len, device=x.device))  # (T, D)
        return self.dropout(tok_emb + pos_emb)               # Broadcasting adds positions

Token embeddings encode what; position embeddings encode where. Their sum gives the model both word identity and word order (Lecture 15).

Masked multi-head attention


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        Q = self.q_linear(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_linear(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_linear(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Q, K, V shape: (B, n_heads, T, head_dim)

Attention computation


        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values and concatenate heads
        out = torch.matmul(attn_weights, V)                  # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, D) # (B, T, D)
        return self.out_linear(out)

def create_causal_mask(seq_len, device):
    return torch.tril(torch.ones(seq_len, seq_len, device=device))

Each row i has 1s at positions 0..i and 0s at positions i+1..T-1, blocking attention to future tokens.
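
For example, with seq_len = 4 the mask is lower-triangular (output shown as a comment):

print(create_causal_mask(4, 'cpu'))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])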

Feed-forward network and transformer block


class FeedForward(nn.Module):
    def __init__(self, d_model, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),                        # GPT uses GELU, not ReLU
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dropout)

    def forward(self, x, mask):
        x = x + self.attention(self.ln1(x), mask)   # Pre-norm + residual
        x = x + self.ffn(self.ln2(x))               # Pre-norm + residual
        return x

Complete GPT model


class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_seq_len, dropout):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embeddings = Embeddings(vocab_size, d_model, max_seq_len, dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

GPT forward pass


    def forward(self, x, targets=None):
        seq_len = x.size(1)
        mask = create_causal_mask(seq_len, x.device)

        x = self.embeddings(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)          # (batch, seq_len, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

With d_model=384, n_layers=6, n_heads=6, and vocab_size=50257, this model has approximately 30 million parameters with the weight tying introduced a few slides below (about 49M untied), small enough to train on a single GPU in a few hours.
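
Once the model is instantiated (next slide), you can check the count directly by summing the number of elements in every trainable tensor:

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ~49M untied; ~30M with weight tying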

Training loop


model = GPT(
    config['vocab_size'], config['d_model'], config['n_layers'],
    config['n_heads'], config['max_seq_len'], config['dropout']
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=config['learning_rate'],
    betas=(0.9, 0.95), weight_decay=0.1
)

model.train()
for epoch in range(config['num_epochs']):
    total_loss = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        logits, loss = model(x, targets=y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader):.4f}")

Weight tying


class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_seq_len, dropout):
        super().__init__()
        self.embeddings = Embeddings(vocab_size, d_model, max_seq_len, dropout)
        # ... blocks, ln_f as before ...
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Tie weights: lm_head uses same matrix as token embeddings
        self.lm_head.weight = self.embeddings.token_embed.weight
        # Saves vocab_size * d_model parameters (~19M here, nearly 40% of this model)

Both layers map between token IDs and embedding space — just in opposite directions. Tying them forces consistent representations. Used in GPT-2, LLaMA, and most modern LLMs (Lecture 21).
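
A one-line check confirms the two layers really share storage (assuming the tied model above has been instantiated); the saving is exactly vocab_size × d_model weights:

assert model.lm_head.weight is model.embeddings.token_embed.weight
shared = config['vocab_size'] * config['d_model']
print(f"{shared / 1e6:.1f}M parameters shared")  # ≈19.3M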

Gradient accumulation


accumulation_steps = 8  # Simulate 8x larger batch
optimizer.zero_grad()

for i, (x, y) in enumerate(dataloader):
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, targets=y)
    loss = loss / accumulation_steps  # Normalize
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

Can't fit batch_size=256? Use batch_size=32 with 8 accumulation steps. Mathematically equivalent gradients, 8× less memory. Essential for training on consumer GPUs.
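
The equivalence is easy to verify on a toy problem. This sketch compares one full-batch gradient against two accumulated half-batch gradients (mean-reduction loss, as in our training loop):

lin = nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)

F.mse_loss(lin(data), target).backward()          # one full batch of 8
full_grad = lin.weight.grad.clone()

lin.zero_grad()
for i in range(2):                                 # two half-batches of 4
    xb, yb = data[i*4:(i+1)*4], target[i*4:(i+1)*4]
    (F.mse_loss(lin(xb), yb) / 2).backward()       # normalize by accumulation steps
assert torch.allclose(full_grad, lin.weight.grad, atol=1e-6)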

Learning rate scheduling


import math

def get_lr(step, warmup_steps=1000, max_steps=50000, max_lr=3e-4, min_lr=3e-5):
    if step < warmup_steps:
        return max_lr * step / warmup_steps          # Linear warmup
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))

# Usage in training loop
for step in range(max_steps):
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # ... training step ...

Warmup prevents early instability (random weights → wild gradients). Cosine decay gives diminishing returns gracefully. This is the schedule used by GPT-3, LLaMA, and most modern LLMs.
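
Printing a few points shows the shape of the schedule:

for step in [0, 500, 1000, 25000, 50000]:
    print(f"step {step:>6}: lr = {get_lr(step):.2e}")
# Climbs linearly to 3e-4 by step 1,000, then follows a cosine down to 3e-5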

Mixed precision training


scaler = torch.amp.GradScaler()  # Loss scaling guards against float16 underflow;
                                 # with bfloat16 it is not strictly needed, but harmless

for x, y in dataloader:
    x, y = x.to(device), y.to(device)

    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, loss = model(x, targets=y)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

bfloat16 halves memory and doubles throughput on modern GPUs with minimal quality loss. The GradScaler guards against underflow when training in float16; with bfloat16 (which keeps float32's exponent range) it is optional. All serious training runs (nanoGPT, LLaMA, etc.) use mixed precision.

Training best practices

| Technique | Why it matters |
| --- | --- |
| Gradient clipping (max norm 1.0) | Prevents exploding gradients that destabilize training |
| AdamW optimizer | Decoupled weight decay; better than standard Adam for transformers |
| Learning rate warmup | Gradually increase LR over first ~1,000 steps to avoid early instability |
| Cosine LR decay | Smoothly reduce LR after warmup for fine-grained convergence |
| Mixed precision (float16/bfloat16) | 2–3× speedup on modern GPUs with minimal quality loss |

If loss is not decreasing: check learning rate (try 1e-4 to 3e-4), verify data pipeline outputs correct input/target pairs, and watch for NaN values. If you run out of memory: reduce batch size or sequence length first, then try gradient accumulation.
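
One more useful sanity check: with randomly initialized weights the model should be close to uniform over the vocabulary, so the very first loss should be near ln(vocab_size):

import math
print(f"Expected initial loss ≈ {math.log(config['vocab_size']):.2f}")  # ln(50257) ≈ 10.82
# A first loss far from this value usually means a data-pipeline or initialization bug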

Greedy decoding


@torch.no_grad()
def generate_greedy(model, tokenizer, prompt, max_new_tokens=50):
    model.eval()
    tokens = tokenizer.encode(prompt)
    x = torch.tensor([tokens], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        x_crop = x[:, -model.max_seq_len:]
        logits, _ = model(x_crop)
        logits = logits[:, -1, :]                # Last position only
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        x = torch.cat([x, next_token], dim=1)

    return tokenizer.decode(x[0].tolist())

Always picking argmax produces the same output every time and tends to get stuck in repetitive loops. Real applications use sampling to introduce controlled randomness.

Sampling strategies

  • Temperature (T): Scale logits by 1/T before softmax. T < 1 sharpens the distribution (more conservative); T > 1 flattens it (more creative).
  • Top-k: Sample only from the k most probable tokens. Typical: k = 40.
  • Nucleus (top-p): Sample from the smallest set of tokens whose cumulative probability exceeds p. Typical: p = 0.9 or 0.95.

| Strategy | Pros | Cons |
| --- | --- | --- |
| Greedy | Deterministic, fast | Repetitive, boring |
| Temperature | Simple control knob | Can produce nonsense at high T |
| Top-k | Filters unlikely tokens | Fixed k may be too broad or narrow |
| Nucleus (top-p) | Adapts to distribution shape | Slightly more complex |
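
A tiny demo of the temperature knob on three made-up logits (values here are illustrative):

logits = torch.tensor([2.0, 1.0, 0.0])
for T in [0.5, 1.0, 2.0]:
    print(f"T={T}: {F.softmax(logits / T, dim=-1)}")
# T=0.5 concentrates mass on the top token; T=2.0 spreads it out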

Implementing sampling


def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None):
    logits = logits / temperature

    if top_k is not None:
        top_k = min(top_k, logits.size(-1))
        threshold = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < threshold] = float('-inf')

    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumprobs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits[sorted_idx[remove]] = float('-inf')

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Complete generation function


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100,
             temperature=0.8, top_k=40, top_p=0.9):
    model.eval()
    tokens = tokenizer.encode(prompt)
    x = torch.tensor([tokens], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        x_crop = x[:, -model.max_seq_len:]
        logits, _ = model(x_crop)
        next_token = sample_next_token(
            logits[0, -1, :], temperature=temperature, top_k=top_k, top_p=top_p
        )
        x = torch.cat([x, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(x[0].tolist())

# Generate text
print(generate(model, tokenizer, "Once upon a time"))

Debugging and common issues

| Problem | Likely cause | Fix |
| --- | --- | --- |
| Loss not decreasing | LR too high/low, data bug | Try LR in [1e-4, 3e-4]; verify x/y offset |
| Out of memory | Batch/sequence too large | Reduce batch size; use gradient accumulation |
| Poor generation quality | Undertrained | Train longer; use more/better data |
| Repetitive output | Greedy decoding or low temperature | Use nucleus sampling (p = 0.9) |
| NaN loss | Numerical instability | Add gradient clipping; check for empty batches |

KV cache: fast generation

Without caching, generating token n requires recomputing attention over all n-1 previous tokens. The KV cache stores previously computed key/value tensors and only computes Q/K/V for the new token.


# Pseudocode sketch: the GPT class above does not implement past_kv
# Without KV cache: recompute everything each step → O(n²) per token
for step in range(100):
    logits = model(all_tokens[:step+1])  # Reprocesses all tokens

# With KV cache: only process the NEW token → O(n) per token
cache = {}
for step in range(100):
    logits, cache = model(new_token_only, past_kv=cache)
    # cache stores K, V from all previous steps

KV caching makes generation 10–50× faster for long sequences. It's the reason ChatGPT responds in seconds, not minutes. The tradeoff is memory: the cache grows linearly with sequence length.
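
A back-of-envelope estimate for our config (assuming a float32 cache; real systems use 16-bit):

# K and V, one d_model-sized vector each, per layer, per token
bytes_per_token = 2 * config['n_layers'] * config['d_model'] * 4
print(f"{bytes_per_token / 1024:.0f} KB per token")                             # ≈18 KB
print(f"{bytes_per_token * config['max_seq_len'] / 1e6:.1f} MB at 256 tokens")  # ≈4.7 MB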

FlashAttention: memory-efficient attention

Standard attention materializes the full N×N attention matrix in GPU memory, requiring O(N²) memory. FlashAttention computes attention in tiled blocks that fit in fast SRAM, never materializing the full matrix.

| Metric | Standard attention | FlashAttention-3 |
| --- | --- | --- |
| Memory | O(N²) | O(N) |
| Speed | Baseline | 1.5–2× faster |
| Max sequence | ~4K (memory limited) | 128K+ |
| Wall-clock for 8K seq | ~50ms | ~25ms |

FlashAttention-3 adds hardware-aware pipelining for H100 GPUs, approaching 75% of theoretical FLOPS.
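
To see why the full matrix hurts, count its entries. This rough estimate assumes float32, batch size 1, and our 6 heads (softmax and dropout buffers add even more):

for N in [4096, 8192]:
    attn_bytes = config['n_heads'] * N * N * 4   # one N×N float32 matrix per head
    print(f"N={N}: {attn_bytes / 1e9:.2f} GB for attention scores alone")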

GPT-2 vs modern decoders


# GPT-2 style (what we built today)
self.ln = nn.LayerNorm(d_model)              # LayerNorm
self.ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),                                # GELU activation
    nn.Linear(4 * d_model, d_model))
self.pos_embed = nn.Embedding(max_len, d)    # Learned absolute positions

# Modern style (LLaMA, Mistral, etc.)
self.ln = RMSNorm(d_model)                   # RMSNorm: 10-15% faster
self.ffn = SwiGLU(d_model, int(8/3 * d_model))  # SwiGLU: ~1% better
self.pos_embed = None  # RoPE applied inside attention (extrapolates to any length)

Our mini-GPT is architecturally identical to GPT-2. To reach LLaMA-class performance, swap in RMSNorm, SwiGLU, RoPE, and GQA — all incremental changes to the same basic structure. The conceptual framework you built today is the same one powering frontier models.
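
For a sense of how small these swaps are, here is a minimal RMSNorm sketch (an illustrative version, not the exact LLaMA code):

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        # Like LayerNorm, but no mean subtraction and no bias: rescale each
        # vector by its root-mean-square, then apply a learned gain
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)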

How our mini-GPT compares to nanoGPT

| Feature | Our mini-GPT | nanoGPT | GPT-2 (124M) |
| --- | --- | --- | --- |
| Parameters | ~30M | ~124M | 124M |
| Context length | 256 | 1024 | 1024 |
| Training data | Shakespeare | OpenWebText | WebText |
| Attention | Standard | FlashAttention | Standard |
| Precision | float32 | bfloat16 | float32 |
| Weight tying | Yes (add it!) | Yes | Yes |
| Time to train | ~1 hour (GPU) | ~4 hours (A100) | Days (256 TPUs) |

Our mini-GPT shares nanoGPT's architecture at a smaller scale. To reach nanoGPT-class performance: scale up d_model to 768 and n_layers to 12, add FlashAttention, use mixed precision, and train on a larger corpus.

Further reading

Karpathy, "Let's build GPT" Step-by-step video tutorial (2h) — the inspiration for this lecture.

nanoGPT Clean, minimal GPT-2 implementation in ~300 lines of PyTorch.

Dao et al. (2022, NeurIPS) "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" — O(N) memory attention.

Shah et al. (2024, arXiv) "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" — Latest version for H100s.

Radford et al. (2018) "Improving Language Understanding by Generative Pre-Training" — The original GPT paper.

Questions?

📧 Email
💬 Discord

Week 9 (after break): Agents and tool use, giving language models the ability to act