* 2. Include this script: * 3. Create charts with minimal configuration - colors are auto-applied! */ (function() { 'use strict'; // ========================================================================== // READ COLORS FROM CSS CUSTOM PROPERTIES // This ensures chart colors stay in sync with the theme // ========================================================================== /** * Get a CSS custom property value from :root */ function getCSSVar(name, fallback = '') { if (typeof getComputedStyle === 'undefined') return fallback; const value = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); return value || fallback; } /** * Build palette from CSS custom properties (with fallbacks) */ function buildPaletteFromCSS() { return { // Primary brand colors dartmouthGreen: getCSSVar('--dartmouth-green', '#00693e'), textPrimary: getCSSVar('--text-primary', '#0a2518'), textSecondary: getCSSVar('--text-secondary', '#0a3d23'), // Chart colors (from CSS --chart-color-N variables) chartColors: [ getCSSVar('--chart-color-1', '#00693e'), getCSSVar('--chart-color-2', '#267aba'), getCSSVar('--chart-color-3', '#ffa00f'), getCSSVar('--chart-color-4', '#9d162e'), getCSSVar('--chart-color-5', '#8a6996'), getCSSVar('--chart-color-6', '#a5d75f'), getCSSVar('--chart-color-7', '#003c73'), getCSSVar('--chart-color-8', '#d94415'), getCSSVar('--chart-color-9', '#643c20'), getCSSVar('--chart-color-10', '#c4dd88'), getCSSVar('--chart-color-11', '#f5dc69'), getCSSVar('--chart-color-12', '#424141'), ], // Background colors (semi-transparent versions) chartBgColors: [ getCSSVar('--chart-bg-1', 'rgba(0, 105, 62, 0.5)'), getCSSVar('--chart-bg-2', 'rgba(38, 122, 186, 0.5)'), getCSSVar('--chart-bg-3', 'rgba(255, 160, 15, 0.5)'), getCSSVar('--chart-bg-4', 'rgba(157, 22, 46, 0.5)'), getCSSVar('--chart-bg-5', 'rgba(138, 105, 150, 0.5)'), getCSSVar('--chart-bg-6', 'rgba(165, 215, 95, 0.5)'), ], // Semantic colors positive: getCSSVar('--chart-positive', '#00693e'), negative: 
getCSSVar('--chart-negative', '#9d162e'), neutral: getCSSVar('--chart-neutral', '#424141'), highlight: getCSSVar('--chart-highlight', '#ffa00f'), // Grid and axis colors gridLight: getCSSVar('--chart-grid-light', 'rgba(0, 105, 62, 0.1)'), gridMedium: getCSSVar('--chart-grid-medium', 'rgba(0, 105, 62, 0.15)'), gridDark: getCSSVar('--chart-grid-dark', 'rgba(0, 105, 62, 0.2)'), axisColor: getCSSVar('--chart-axis-color', '#0a2518'), // Font fontFamily: getCSSVar('--chart-font-family', "'Avenir LT Std', 'Avenir', 'Avenir Next', -apple-system, BlinkMacSystemFont, sans-serif"), }; } // Initialize palette (will be populated when DOM is ready) let CDL_PALETTE = null; // For convenience, expose primary chart colors array let CHART_COLORS = null; // ========================================================================== // FONT CONFIGURATION // Responsive font sizes based on typical Marp slide dimensions (1280x720) // ========================================================================== const FONT_CONFIG = { sizes: { title: 22, // Chart title subtitle: 18, // Subtitle legend: 16, // Legend labels axisTitle: 18, // Axis titles axisTicks: 16, // Axis tick labels tooltip: 14, // Tooltip text dataLabels: 14, // Data labels on charts }, weight: { normal: 400, medium: 500, bold: 600, }, }; // ========================================================================== // HELPER FUNCTIONS // ========================================================================== /** * Ensure palette is initialized */ function ensurePalette() { if (!CDL_PALETTE) { CDL_PALETTE = buildPaletteFromCSS(); CHART_COLORS = CDL_PALETTE.chartColors; } return CDL_PALETTE; } /** * Get color for a dataset at given index * Cycles through palette if more datasets than colors */ function getColor(index) { ensurePalette(); return CHART_COLORS[index % CHART_COLORS.length]; } /** * Get color with alpha transparency */ function getColorWithAlpha(color, alpha) { // Handle hex colors if (color.startsWith('#')) { 
const r = parseInt(color.slice(1, 3), 16); const g = parseInt(color.slice(3, 5), 16); const b = parseInt(color.slice(5, 7), 16); return `rgba(${r}, ${g}, ${b}, ${alpha})`; } // Handle rgba colors if (color.startsWith('rgba')) { return color.replace(/[\d.]+\)$/, `${alpha})`); } return color; } /** * Generate colors for all datasets in chart data * Automatically assigns colors if not specified */ function autoAssignColors(data, chartType) { if (!data || !data.datasets) return data; data.datasets.forEach((dataset, index) => { const baseColor = getColor(index); // Only assign colors if not already specified switch (chartType) { case 'bar': case 'horizontalBar': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'line': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.1); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 3; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 6; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } if (dataset.tension === undefined) { dataset.tension = 0.3; } break; case 'scatter': case 'bubble': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 15; } if (dataset.pointHoverRadius === undefined) { dataset.pointHoverRadius = 18; } break; case 'pie': case 'doughnut': case 'polarArea': // For pie charts, we need multiple colors for one dataset if (!dataset.backgroundColor) { const numItems = dataset.data ? 
dataset.data.length : 6; dataset.backgroundColor = []; for (let i = 0; i < numItems; i++) { dataset.backgroundColor.push(getColor(i)); } } if (!dataset.borderColor) { dataset.borderColor = '#d8d8d8'; // Slide background } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'radar': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.2); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 4; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } break; default: // Generic color assignment if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } } }); return data; } // ========================================================================== // CHART.JS GLOBAL DEFAULTS // ========================================================================== function applyGlobalDefaults() { if (typeof Chart === 'undefined') { console.warn('Chart.js not loaded. 
chart-defaults.js requires Chart.js to be loaded first.'); return false; } // Ensure palette is loaded from CSS const palette = ensurePalette(); // Font defaults Chart.defaults.font.family = palette.fontFamily; Chart.defaults.font.size = FONT_CONFIG.sizes.axisTicks; Chart.defaults.color = palette.textPrimary; // Responsive defaults Chart.defaults.responsive = true; Chart.defaults.maintainAspectRatio = false; // Animation (subtle) Chart.defaults.animation.duration = 400; // Plugin defaults // Legend Chart.defaults.plugins.legend.labels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, weight: FONT_CONFIG.weight.normal, }; Chart.defaults.plugins.legend.labels.color = palette.textPrimary; Chart.defaults.plugins.legend.labels.usePointStyle = true; Chart.defaults.plugins.legend.labels.padding = 20; // Title Chart.defaults.plugins.title.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.title, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.title.color = palette.textPrimary; // Tooltip Chart.defaults.plugins.tooltip.backgroundColor = palette.textPrimary; Chart.defaults.plugins.tooltip.titleFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.tooltip.bodyFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, }; Chart.defaults.plugins.tooltip.cornerRadius = 4; Chart.defaults.plugins.tooltip.padding = 10; // Scale defaults (for cartesian charts) // These need to be applied per-scale type const scaleDefaults = { grid: { color: palette.gridLight, lineWidth: 1, }, border: { color: palette.gridDark, width: 1, }, ticks: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }, color: palette.textPrimary, }, title: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTitle, weight: FONT_CONFIG.weight.normal, }, color: palette.textPrimary, }, }; // Apply scale defaults to linear scale if (Chart.defaults.scales && 
Chart.defaults.scales.linear) { if (Chart.defaults.scales.linear.grid) Object.assign(Chart.defaults.scales.linear.grid, scaleDefaults.grid); if (Chart.defaults.scales.linear.border) Object.assign(Chart.defaults.scales.linear.border, scaleDefaults.border); if (Chart.defaults.scales.linear.ticks) Object.assign(Chart.defaults.scales.linear.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.linear.title) Object.assign(Chart.defaults.scales.linear.title, scaleDefaults.title); } // Apply scale defaults to category scale if (Chart.defaults.scales && Chart.defaults.scales.category) { if (Chart.defaults.scales.category.grid) Object.assign(Chart.defaults.scales.category.grid, scaleDefaults.grid); if (Chart.defaults.scales.category.border) Object.assign(Chart.defaults.scales.category.border, scaleDefaults.border); if (Chart.defaults.scales.category.ticks) Object.assign(Chart.defaults.scales.category.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.category.title) Object.assign(Chart.defaults.scales.category.title, scaleDefaults.title); } // Apply scale defaults to logarithmic scale if (Chart.defaults.scales && Chart.defaults.scales.logarithmic) { if (Chart.defaults.scales.logarithmic.grid) Object.assign(Chart.defaults.scales.logarithmic.grid, scaleDefaults.grid); if (Chart.defaults.scales.logarithmic.border) Object.assign(Chart.defaults.scales.logarithmic.border, scaleDefaults.border); if (Chart.defaults.scales.logarithmic.ticks) Object.assign(Chart.defaults.scales.logarithmic.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.logarithmic.title) Object.assign(Chart.defaults.scales.logarithmic.title, scaleDefaults.title); } // Apply scale defaults to radial scale (for radar charts) if (Chart.defaults.scales && Chart.defaults.scales.radialLinear) { if (Chart.defaults.scales.radialLinear.grid) Chart.defaults.scales.radialLinear.grid.color = palette.gridLight; if (Chart.defaults.scales.radialLinear.angleLines) Chart.defaults.scales.radialLinear.angleLines.color = 
palette.gridMedium; if (Chart.defaults.scales.radialLinear.pointLabels) { Chart.defaults.scales.radialLinear.pointLabels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }; Chart.defaults.scales.radialLinear.pointLabels.color = palette.textPrimary; } } return true; } // ========================================================================== // CHART WRAPPER FOR AUTO-STYLING // ========================================================================== /** * Wrap the Chart constructor to automatically apply CDL styling */ function wrapChartConstructor() { if (typeof Chart === 'undefined') return; const OriginalChart = Chart; // Create a wrapper that auto-applies colors window.Chart = function(ctx, config) { // Auto-assign colors if not specified if (config && config.data) { config.data = autoAssignColors(config.data, config.type); } // Merge default options for specific chart types if (config && config.options) { config.options = applyChartTypeDefaults(config.type, config.options); } // Call original constructor return new OriginalChart(ctx, config); }; // Copy static properties and methods Object.setPrototypeOf(window.Chart, OriginalChart); Object.assign(window.Chart, OriginalChart); // Preserve the prototype chain window.Chart.prototype = OriginalChart.prototype; } /** * Apply chart-type specific defaults */ function applyChartTypeDefaults(chartType, userOptions) { const options = { ...userOptions }; switch (chartType) { case 'bar': case 'horizontalBar': // Bar chart defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; // Hide x-axis grid for cleaner look if (options.scales.x.grid === undefined) { options.scales.x.grid = { display: false }; } break; case 'line': // Line chart defaults if (!options.interaction) { options.interaction = { intersect: false, mode: 'index' }; } break; case 'pie': case 'doughnut': // Pie/doughnut defaults if 
(!options.plugins) options.plugins = {}; if (options.plugins.legend === undefined) { const palette = ensurePalette(); options.plugins.legend = { position: 'right', labels: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, }, color: palette.textPrimary, padding: 15, }, }; } break; case 'radar': // Radar chart defaults - keep as-is, scale defaults applied globally break; case 'scatter': case 'bubble': // Scatter/bubble defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; break; } return options; } // ========================================================================== // CONVENIENCE FUNCTIONS FOR USERS // Exposed on window.CDLChart for easy access // ========================================================================== window.CDLChart = { // Color palette access (getters to ensure lazy initialization) get colors() { return ensurePalette().chartColors; }, get palette() { return ensurePalette(); }, // Get specific color by index getColor: getColor, // Get color with transparency getColorWithAlpha: getColorWithAlpha, // Get array of colors for a specific count getColors: function(count) { ensurePalette(); const result = []; for (let i = 0; i < count; i++) { result.push(getColor(i)); } return result; }, // Font configuration fonts: FONT_CONFIG, // Quick chart creation helpers // These create minimal config that auto-applies all styling /** * Create a simple bar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - X-axis labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ bar: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'bar', data: { labels: labels, datasets: [{ data: data }], }, options: { plugins: { legend: { display: false } }, ...options, }, }); }, /** * Create a simple line chart * @param {string} canvasId - 
Canvas element ID * @param {string[]} labels - X-axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ line: function(canvasId, labels, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'line', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, fill: ds.fill !== undefined ? ds.fill : true, })), }, options: options, }); }, /** * Create a simple pie chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ pie: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'pie', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a simple scatter chart * @param {string} canvasId - Canvas element ID * @param {Array} datasets - Array of {label, data: [{x, y}]} objects * @param {object} options - Optional overrides */ scatter: function(canvasId, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'scatter', data: { datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, /** * Create a doughnut chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ doughnut: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'doughnut', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a radar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ radar: function(canvasId, labels, datasets, 
options = {}) { return new Chart(document.getElementById(canvasId), { type: 'radar', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, }; // ========================================================================== // INITIALIZATION // ========================================================================== function initialize() { // Wait for Chart.js to be available if (typeof Chart !== 'undefined') { applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully.'); return true; } else { // Chart.js not yet loaded - wait and retry let retries = 0; const maxRetries = 50; // 5 seconds max wait const checkInterval = setInterval(function() { retries++; if (typeof Chart !== 'undefined') { clearInterval(checkInterval); applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully (after waiting for Chart.js).'); } else if (retries >= maxRetries) { clearInterval(checkInterval); console.warn('Chart.js not found after waiting. CDL Chart defaults not applied.'); } }, 100); return false; } } // Initialize IMMEDIATELY - this must run BEFORE any chart creation scripts // Chart.js CDN should be loaded before this script initialize(); })();

Lecture 16: Training transformer models

PSYC 51.17: Models of language and communication

Jeremy R. Manning
Dartmouth College
Winter 2026

Learning objectives

  1. Explain what it means to "train" a transformer model
  2. Describe how training data is constructed from raw text
  3. Understand cross-entropy loss and perplexity as training metrics
  4. Describe the training loop and key optimization techniques
  5. Distinguish between foundation models, instruct-tuned models, and fine-tuned models

Where we left off

Last time we saw the transformer as a function:

$$\text{Transformer}(X, \theta) \rightarrow Y$$

We traced data from raw text through tokenization, embedding, attention, and feed-forward layers. But we treated the parameters $\theta$ as given.

Where do the parameters $\theta$ come from? How does the model learn to predict the next token?

What does "training" mean?

Training is the process of finding parameters $\theta$ that make the model's predictions match the data. Given a large corpus of text, we want:

$$\theta^* = \arg\min_\theta \mathcal{L}(\theta)$$

where $\mathcal{L}$ is a loss function that measures how wrong the model's predictions are.

Imagine the model is a student learning to finish sentences. We show it millions of sentences, and every time it guesses the next word wrong, we nudge its parameters to make a better guess next time.

The language modeling objective

Given a sequence of tokens $x_1, x_2, \ldots, x_{t-1}$, the model predicts a probability distribution over the vocabulary for the next token $x_t$:

$$P(x_t \mid x_1, x_2, \ldots, x_{t-1}; \theta)$$

The goal is to maximize the probability of the actual next token across the entire training corpus.

No human labels needed! The "label" is simply the next word in the text. This is why we can train on virtually unlimited data from the internet.

From raw text to training data

Training data is built from raw text in three steps:

  1. Tokenize every document in the corpus independently
  2. Concatenate all token sequences into one long stream (optionally separated by <EOS> tokens)
  3. Chunk the stream into fixed-length blocks of block_size tokens (e.g., 1024 or 2048)

Each chunk becomes one training example. Leftover tokens shorter than block_size are dropped.

Padding wastes compute — every [PAD] token is a wasted FLOP. Concatenation ensures every token in every batch is a real training signal. This is the standard approach used by HuggingFace, GPT-2/3, LLaMA, and most modern language models.

The causal attention mask

Within each chunk, the causal attention mask (a lower-triangular matrix) ensures position $t$ can only attend to positions $1, 2, \ldots, t$. This means every position simultaneously predicts its next token — one forward pass through a chunk of length $L$ produces $L$ training signals.

Labels are simply the input shifted right by one position. The model handles this internally during the forward pass.

Think of it like a classroom where every student takes the same test at the same time, but each student can only see the questions before theirs. Student 1 sees nothing and guesses the first word. Student 2 sees word 1 and predicts word 2. Student $L$ sees all previous words and predicts the last. All $L$ students learn simultaneously from a single test.

Cross-entropy loss

Cross-entropy loss measures how different the model's predicted distribution $\hat{y}$ is from the true distribution $y$ (a one-hot vector for the correct token):

$$\mathcal{L} = -\sum_{i=1}^{V} y_i \log(\hat{y}_i)$$

Since $y$ is one-hot (only the correct token has $y_i = 1$), this simplifies to:

$$\mathcal{L} = -\log(\hat{y}_c)$$

where $c$ is the index of the correct next token.

Understanding cross-entropy

  • If the model assigns probability 1.0 to the correct token: $-\log(1.0) = 0$ (no loss!)
  • If the model assigns probability 0.5: $-\log(0.5) = 0.69$
  • If the model assigns probability 0.01: $-\log(0.01) = 4.6$ (high loss!)

The loss penalizes confident wrong predictions heavily and rewards confident correct predictions.


1Vocabulary: [the, cat, sat, on, mat]  (V=5)
2True next token: "sat" → y = [0, 0, 1, 0, 0]
3
4Model prediction: ŷ = [0.1, 0.2, 0.5, 0.15, 0.05]
5Loss = -log(0.5) = 0.69
6
7Better prediction: ŷ = [0.05, 0.05, 0.85, 0.03, 0.02]
8Loss = -log(0.85) = 0.16  ← much lower!

Perplexity

Perplexity is the exponential of the average cross-entropy loss across a sequence:

$$\text{PPL} = e^{\mathcal{L}} = e^{-\frac{1}{N}\sum_{t=1}^{N} \log P(x_t \mid x_{<t})}$$

A perplexity of $k$ means the model is, on average, as uncertain as if it were choosing uniformly among $k$ options.

  • Model comparison: Lower PPL = better model. GPT-2 Small (PPL $\approx$ 27) vs. GPT-2 XL (PPL $\approx$ 17) vs. GPT-3 (PPL $\approx$ 11)
  • Training monitoring: Plot PPL over training steps; it should decrease. Spikes indicate instability.
  • Evaluation: Standard metric on held-out test sets (e.g., WikiText-103)

Only measures token-level prediction accuracy — doesn't capture fluency, coherence, or factual correctness. A model with low PPL can still hallucinate or generate toxic content. Can't compare models with different vocabularies (different PPL scales), and values are domain-dependent.

The training loop

  1. Forward pass: feed a batch of text through the model to get predictions
  2. Compute loss: compare predictions to actual next tokens using cross-entropy
  3. Backward pass: compute gradients $\nabla_\theta \mathcal{L}$ via backpropagation
  4. Update parameters: adjust $\theta$ to reduce the loss

$$\theta \leftarrow \theta - \eta \nabla_\theta \mathcal{L}$$

where $\eta$ is the learning rate — how big a step we take.

GPT-3 has 175 billion parameters. Each training step updates all of them. Training took ~$4.6 million in compute costs and processed ~300 billion tokens. GPT-4 likely cost over $100 million to train, and GPT-5 (trillions of parameters) was even more expensive to train. Energy consumption and environmental impact are significant concerns at this scale.

Optimization and regularization

  • Stochastic gradient descent (SGD): Use random mini-batches instead of the entire dataset. Faster and noisier, but works well in practice.
  • AdamW optimizer: Gives each parameter its own adaptive learning rate based on gradient history. Adds momentum (smooths updates) and weight decay (prevents overfitting). The standard choice for transformers.
  • Learning rate scheduling: Start with a warmup phase (gradually increase LR), then decay with a cosine or linear schedule. Prevents instability early in training.
  • Gradient clipping: Cap gradient magnitude to prevent "exploding gradients" from destroying training progress. Typical max norm = 1.0.

Loshchilov & Hutter (2019, ICLR) "Decoupled Weight Decay Regularization" — The AdamW optimizer paper.

Vaswani et al. (2017, NeurIPS) "Attention Is All You Need" — Introduced warmup scheduling for transformers.

Scaling laws

Language model performance follows predictable power laws: smooth, straight lines on log-log plots when you increase model size, dataset size, or compute budget.

  • Kaplan et al. (2020): Discovered these relationships across seven orders of magnitude
  • Hoffmann et al. (2022, "Chinchilla"): Showed optimal training balances model size and data — a smaller model trained on more data can match a larger undertrained model (Chinchilla 70B matched GPT-3 175B)

Key takeaway: you can predict model performance before training by running small-scale experiments first.

Kaplan et al. (2020, arXiv) "Scaling Laws for Neural Language Models"

Hoffmann et al. (2022, arXiv) "Training Compute-Optimal Large Language Models" (Chinchilla)

Training with HuggingFace


1from transformers import AutoModelForCausalLM, AutoTokenizer
2from transformers import Trainer, TrainingArguments
3from datasets import load_dataset
4
5model = AutoModelForCausalLM.from_pretrained("gpt2")
6tokenizer = AutoTokenizer.from_pretrained("gpt2")
7
8dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
9def tokenize(examples):
10    return tokenizer(examples["text"], truncation=True, max_length=512)
11tokenized = dataset.map(tokenize, batched=True)
12
13args = TrainingArguments(
14    output_dir="./results", num_train_epochs=3,
15    per_device_train_batch_size=8, learning_rate=5e-5,
16    warmup_steps=500, weight_decay=0.01,
17)
18trainer = Trainer(model=model, args=args, train_dataset=tokenized["train"])
19trainer.train()

From pre-training to deployment

  1. Pre-training $\rightarrow$ Foundation model: Train on a massive general text corpus (books, web, code). The model learns language patterns, world knowledge, and reasoning. Expensive ($millions), done once. Examples: GPT-3 base, LLaMA, Mistral.

  2. Instruction tuning + RLHF $\rightarrow$ Instruct model: Fine-tune the foundation model on instruction-following data and human preference feedback. The model learns to be helpful, harmless, and honest. This is what makes ChatGPT different from raw GPT-4. Examples: ChatGPT, Claude, Gemini.

  3. Task-specific fine-tuning $\rightarrow$ Specialized model: Further fine-tune on domain data for a specific application. Cheap, fast, and highly effective. Examples: medical diagnosis, legal analysis, code generation for a specific codebase.

Why instruction tuning matters

A foundation model has vast knowledge but no "manners" — it will complete any text prompt, including harmful or nonsensical ones. Instruction tuning teaches the model to follow instructions, answer questions helpfully, and refuse harmful requests.

RLHF (Reinforcement Learning from Human Feedback): Humans rank model outputs, and the model is trained to prefer higher-ranked responses. This is the key technique behind ChatGPT, Claude, and other assistants.

Ouyang et al. (2022, NeurIPS) "Training language models to follow instructions with human feedback" — The InstructGPT paper that launched the instruction-tuning revolution.

Bai et al. (2022, arXiv) "Constitutional AI: Harmlessness from AI Feedback" — An alternative to human feedback.

Task-specific fine-tuning

Take a pre-trained (or instruct-tuned) model and train it further on task-specific data:

  • Only needs small datasets (hundreds to thousands of examples)
  • Uses tiny learning rates (10-100x smaller than pre-training) to preserve existing knowledge
  • Training takes minutes to hours, not weeks

The pre-trained model already understands language, facts, and reasoning. Fine-tuning just teaches it the format and focus of your task — like an expert learning a new specialty. Fine-tuning on 1,000 labeled examples typically outperforms training from scratch on 100,000 examples.

Fine-tuning with HuggingFace


1from transformers import AutoModelForSequenceClassification
2from transformers import Trainer, TrainingArguments
3from datasets import load_dataset
4
5model = AutoModelForSequenceClassification.from_pretrained(
6    "bert-base-uncased", num_labels=2
7)
8dataset = load_dataset("imdb")
9
10args = TrainingArguments(
11    output_dir="./sentiment-model", num_train_epochs=3,
12    per_device_train_batch_size=16, learning_rate=2e-5,
13    warmup_ratio=0.1, weight_decay=0.01,
14    evaluation_strategy="epoch",
15)
16trainer = Trainer(
17    model=model, args=args,
18    train_dataset=dataset["train"], eval_dataset=dataset["test"],
19)
20trainer.train()

Practical training advice

  1. Start small: Test your pipeline with a small model before scaling up
  2. Learning rate: Fine-tuning uses 10-100x smaller learning rates than pre-training (typically 1e-5 to 5e-5)
  3. Epochs: Fine-tuning needs only 2-5 epochs; pre-training typically makes 1-2 passes over the data since the dataset is so large
  4. Batch size: Larger batches give more stable gradients but require more memory. Use gradient accumulation if your GPU is too small
  5. Evaluation: Always hold out a validation set and monitor for overfitting

Questions?

📧 Email
💬 Discord

Retrieval Augmented Generation (RAG): giving language models access to external knowledge!

split: 22

split: 24