* 2. Include this script: * 3. Create charts with minimal configuration - colors are auto-applied! */ (function() { 'use strict'; // ========================================================================== // READ COLORS FROM CSS CUSTOM PROPERTIES // This ensures chart colors stay in sync with the theme // ========================================================================== /** * Get a CSS custom property value from :root */ function getCSSVar(name, fallback = '') { if (typeof getComputedStyle === 'undefined') return fallback; const value = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); return value || fallback; } /** * Build palette from CSS custom properties (with fallbacks) */ function buildPaletteFromCSS() { return { // Primary brand colors dartmouthGreen: getCSSVar('--dartmouth-green', '#00693e'), textPrimary: getCSSVar('--text-primary', '#0a2518'), textSecondary: getCSSVar('--text-secondary', '#0a3d23'), // Chart colors (from CSS --chart-color-N variables) chartColors: [ getCSSVar('--chart-color-1', '#00693e'), getCSSVar('--chart-color-2', '#267aba'), getCSSVar('--chart-color-3', '#ffa00f'), getCSSVar('--chart-color-4', '#9d162e'), getCSSVar('--chart-color-5', '#8a6996'), getCSSVar('--chart-color-6', '#a5d75f'), getCSSVar('--chart-color-7', '#003c73'), getCSSVar('--chart-color-8', '#d94415'), getCSSVar('--chart-color-9', '#643c20'), getCSSVar('--chart-color-10', '#c4dd88'), getCSSVar('--chart-color-11', '#f5dc69'), getCSSVar('--chart-color-12', '#424141'), ], // Background colors (semi-transparent versions) chartBgColors: [ getCSSVar('--chart-bg-1', 'rgba(0, 105, 62, 0.5)'), getCSSVar('--chart-bg-2', 'rgba(38, 122, 186, 0.5)'), getCSSVar('--chart-bg-3', 'rgba(255, 160, 15, 0.5)'), getCSSVar('--chart-bg-4', 'rgba(157, 22, 46, 0.5)'), getCSSVar('--chart-bg-5', 'rgba(138, 105, 150, 0.5)'), getCSSVar('--chart-bg-6', 'rgba(165, 215, 95, 0.5)'), ], // Semantic colors positive: getCSSVar('--chart-positive', '#00693e'), negative: 
getCSSVar('--chart-negative', '#9d162e'), neutral: getCSSVar('--chart-neutral', '#424141'), highlight: getCSSVar('--chart-highlight', '#ffa00f'), // Grid and axis colors gridLight: getCSSVar('--chart-grid-light', 'rgba(0, 105, 62, 0.1)'), gridMedium: getCSSVar('--chart-grid-medium', 'rgba(0, 105, 62, 0.15)'), gridDark: getCSSVar('--chart-grid-dark', 'rgba(0, 105, 62, 0.2)'), axisColor: getCSSVar('--chart-axis-color', '#0a2518'), // Font fontFamily: getCSSVar('--chart-font-family', "'Avenir LT Std', 'Avenir', 'Avenir Next', -apple-system, BlinkMacSystemFont, sans-serif"), }; } // Initialize palette (will be populated when DOM is ready) let CDL_PALETTE = null; // For convenience, expose primary chart colors array let CHART_COLORS = null; // ========================================================================== // FONT CONFIGURATION // Responsive font sizes based on typical Marp slide dimensions (1280x720) // ========================================================================== const FONT_CONFIG = { sizes: { title: 22, // Chart title subtitle: 18, // Subtitle legend: 16, // Legend labels axisTitle: 18, // Axis titles axisTicks: 16, // Axis tick labels tooltip: 14, // Tooltip text dataLabels: 14, // Data labels on charts }, weight: { normal: 400, medium: 500, bold: 600, }, }; // ========================================================================== // HELPER FUNCTIONS // ========================================================================== /** * Ensure palette is initialized */ function ensurePalette() { if (!CDL_PALETTE) { CDL_PALETTE = buildPaletteFromCSS(); CHART_COLORS = CDL_PALETTE.chartColors; } return CDL_PALETTE; } /** * Get color for a dataset at given index * Cycles through palette if more datasets than colors */ function getColor(index) { ensurePalette(); return CHART_COLORS[index % CHART_COLORS.length]; } /** * Get color with alpha transparency */ function getColorWithAlpha(color, alpha) { // Handle hex colors if (color.startsWith('#')) { 
const r = parseInt(color.slice(1, 3), 16); const g = parseInt(color.slice(3, 5), 16); const b = parseInt(color.slice(5, 7), 16); return `rgba(${r}, ${g}, ${b}, ${alpha})`; } // Handle rgba colors if (color.startsWith('rgba')) { return color.replace(/[\d.]+\)$/, `${alpha})`); } return color; } /** * Generate colors for all datasets in chart data * Automatically assigns colors if not specified */ function autoAssignColors(data, chartType) { if (!data || !data.datasets) return data; data.datasets.forEach((dataset, index) => { const baseColor = getColor(index); // Only assign colors if not already specified switch (chartType) { case 'bar': case 'horizontalBar': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'line': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.1); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 3; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 6; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } if (dataset.tension === undefined) { dataset.tension = 0.3; } break; case 'scatter': case 'bubble': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 15; } if (dataset.pointHoverRadius === undefined) { dataset.pointHoverRadius = 18; } break; case 'pie': case 'doughnut': case 'polarArea': // For pie charts, we need multiple colors for one dataset if (!dataset.backgroundColor) { const numItems = dataset.data ? 
dataset.data.length : 6; dataset.backgroundColor = []; for (let i = 0; i < numItems; i++) { dataset.backgroundColor.push(getColor(i)); } } if (!dataset.borderColor) { dataset.borderColor = '#d8d8d8'; // Slide background } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'radar': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.2); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 4; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } break; default: // Generic color assignment if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } } }); return data; } // ========================================================================== // CHART.JS GLOBAL DEFAULTS // ========================================================================== function applyGlobalDefaults() { if (typeof Chart === 'undefined') { console.warn('Chart.js not loaded. 
chart-defaults.js requires Chart.js to be loaded first.'); return false; } // Ensure palette is loaded from CSS const palette = ensurePalette(); // Font defaults Chart.defaults.font.family = palette.fontFamily; Chart.defaults.font.size = FONT_CONFIG.sizes.axisTicks; Chart.defaults.color = palette.textPrimary; // Responsive defaults Chart.defaults.responsive = true; Chart.defaults.maintainAspectRatio = false; // Animation (subtle) Chart.defaults.animation.duration = 400; // Plugin defaults // Legend Chart.defaults.plugins.legend.labels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, weight: FONT_CONFIG.weight.normal, }; Chart.defaults.plugins.legend.labels.color = palette.textPrimary; Chart.defaults.plugins.legend.labels.usePointStyle = true; Chart.defaults.plugins.legend.labels.padding = 20; // Title Chart.defaults.plugins.title.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.title, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.title.color = palette.textPrimary; // Tooltip Chart.defaults.plugins.tooltip.backgroundColor = palette.textPrimary; Chart.defaults.plugins.tooltip.titleFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.tooltip.bodyFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, }; Chart.defaults.plugins.tooltip.cornerRadius = 4; Chart.defaults.plugins.tooltip.padding = 10; // Scale defaults (for cartesian charts) // These need to be applied per-scale type const scaleDefaults = { grid: { color: palette.gridLight, lineWidth: 1, }, border: { color: palette.gridDark, width: 1, }, ticks: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }, color: palette.textPrimary, }, title: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTitle, weight: FONT_CONFIG.weight.normal, }, color: palette.textPrimary, }, }; // Apply scale defaults to linear scale if (Chart.defaults.scales && 
Chart.defaults.scales.linear) { if (Chart.defaults.scales.linear.grid) Object.assign(Chart.defaults.scales.linear.grid, scaleDefaults.grid); if (Chart.defaults.scales.linear.border) Object.assign(Chart.defaults.scales.linear.border, scaleDefaults.border); if (Chart.defaults.scales.linear.ticks) Object.assign(Chart.defaults.scales.linear.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.linear.title) Object.assign(Chart.defaults.scales.linear.title, scaleDefaults.title); } // Apply scale defaults to category scale if (Chart.defaults.scales && Chart.defaults.scales.category) { if (Chart.defaults.scales.category.grid) Object.assign(Chart.defaults.scales.category.grid, scaleDefaults.grid); if (Chart.defaults.scales.category.border) Object.assign(Chart.defaults.scales.category.border, scaleDefaults.border); if (Chart.defaults.scales.category.ticks) Object.assign(Chart.defaults.scales.category.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.category.title) Object.assign(Chart.defaults.scales.category.title, scaleDefaults.title); } // Apply scale defaults to logarithmic scale if (Chart.defaults.scales && Chart.defaults.scales.logarithmic) { if (Chart.defaults.scales.logarithmic.grid) Object.assign(Chart.defaults.scales.logarithmic.grid, scaleDefaults.grid); if (Chart.defaults.scales.logarithmic.border) Object.assign(Chart.defaults.scales.logarithmic.border, scaleDefaults.border); if (Chart.defaults.scales.logarithmic.ticks) Object.assign(Chart.defaults.scales.logarithmic.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.logarithmic.title) Object.assign(Chart.defaults.scales.logarithmic.title, scaleDefaults.title); } // Apply scale defaults to radial scale (for radar charts) if (Chart.defaults.scales && Chart.defaults.scales.radialLinear) { if (Chart.defaults.scales.radialLinear.grid) Chart.defaults.scales.radialLinear.grid.color = palette.gridLight; if (Chart.defaults.scales.radialLinear.angleLines) Chart.defaults.scales.radialLinear.angleLines.color = 
palette.gridMedium; if (Chart.defaults.scales.radialLinear.pointLabels) { Chart.defaults.scales.radialLinear.pointLabels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }; Chart.defaults.scales.radialLinear.pointLabels.color = palette.textPrimary; } } return true; } // ========================================================================== // CHART WRAPPER FOR AUTO-STYLING // ========================================================================== /** * Wrap the Chart constructor to automatically apply CDL styling */ function wrapChartConstructor() { if (typeof Chart === 'undefined') return; const OriginalChart = Chart; // Create a wrapper that auto-applies colors window.Chart = function(ctx, config) { // Auto-assign colors if not specified if (config && config.data) { config.data = autoAssignColors(config.data, config.type); } // Merge default options for specific chart types if (config && config.options) { config.options = applyChartTypeDefaults(config.type, config.options); } // Call original constructor return new OriginalChart(ctx, config); }; // Copy static properties and methods Object.setPrototypeOf(window.Chart, OriginalChart); Object.assign(window.Chart, OriginalChart); // Preserve the prototype chain window.Chart.prototype = OriginalChart.prototype; } /** * Apply chart-type specific defaults */ function applyChartTypeDefaults(chartType, userOptions) { const options = { ...userOptions }; switch (chartType) { case 'bar': case 'horizontalBar': // Bar chart defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; // Hide x-axis grid for cleaner look if (options.scales.x.grid === undefined) { options.scales.x.grid = { display: false }; } break; case 'line': // Line chart defaults if (!options.interaction) { options.interaction = { intersect: false, mode: 'index' }; } break; case 'pie': case 'doughnut': // Pie/doughnut defaults if 
(!options.plugins) options.plugins = {}; if (options.plugins.legend === undefined) { const palette = ensurePalette(); options.plugins.legend = { position: 'right', labels: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, }, color: palette.textPrimary, padding: 15, }, }; } break; case 'radar': // Radar chart defaults - keep as-is, scale defaults applied globally break; case 'scatter': case 'bubble': // Scatter/bubble defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; break; } return options; } // ========================================================================== // CONVENIENCE FUNCTIONS FOR USERS // Exposed on window.CDLChart for easy access // ========================================================================== window.CDLChart = { // Color palette access (getters to ensure lazy initialization) get colors() { return ensurePalette().chartColors; }, get palette() { return ensurePalette(); }, // Get specific color by index getColor: getColor, // Get color with transparency getColorWithAlpha: getColorWithAlpha, // Get array of colors for a specific count getColors: function(count) { ensurePalette(); const result = []; for (let i = 0; i < count; i++) { result.push(getColor(i)); } return result; }, // Font configuration fonts: FONT_CONFIG, // Quick chart creation helpers // These create minimal config that auto-applies all styling /** * Create a simple bar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - X-axis labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ bar: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'bar', data: { labels: labels, datasets: [{ data: data }], }, options: { plugins: { legend: { display: false } }, ...options, }, }); }, /** * Create a simple line chart * @param {string} canvasId - 
Canvas element ID * @param {string[]} labels - X-axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ line: function(canvasId, labels, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'line', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, fill: ds.fill !== undefined ? ds.fill : true, })), }, options: options, }); }, /** * Create a simple pie chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ pie: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'pie', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a simple scatter chart * @param {string} canvasId - Canvas element ID * @param {Array} datasets - Array of {label, data: [{x, y}]} objects * @param {object} options - Optional overrides */ scatter: function(canvasId, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'scatter', data: { datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, /** * Create a doughnut chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ doughnut: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'doughnut', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a radar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ radar: function(canvasId, labels, datasets, 
options = {}) { return new Chart(document.getElementById(canvasId), { type: 'radar', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, }; // ========================================================================== // INITIALIZATION // ========================================================================== function initialize() { // Wait for Chart.js to be available if (typeof Chart !== 'undefined') { applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully.'); return true; } else { // Chart.js not yet loaded - wait and retry let retries = 0; const maxRetries = 50; // 5 seconds max wait const checkInterval = setInterval(function() { retries++; if (typeof Chart !== 'undefined') { clearInterval(checkInterval); applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully (after waiting for Chart.js).'); } else if (retries >= maxRetries) { clearInterval(checkInterval); console.warn('Chart.js not found after waiting. CDL Chart defaults not applied.'); } }, 100); return false; } } // Initialize IMMEDIATELY - this must run BEFORE any chart creation scripts // Chart.js CDN should be loaded before this script initialize(); })();
**Components we'll cover:**
1. Tokenization (Byte-Pair Encoding)
2. Embeddings (token + position)
3. Masked Multi-Head Attention
4. Transformer Decoder Block
5. Language Model Head
6. Training Loop
7. Text Generation
We'll build a "nano-GPT"—small enough to train on a laptop, but with the same architecture as the real thing!
class TransformerBlock(nn.Module):
    """Transformer block in pre-norm form.

    Each sub-layer (attention, then feed-forward) receives a LayerNorm-ed
    copy of its input, and its output is added back to that input.
    """

    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        # Sub-layers
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, dropout)

        # One LayerNorm per sub-layer, applied to the sub-layer's input
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # Residual branch 1: normalize, attend under the mask, add back
        attended = self.attention(self.ln1(x), mask)
        x = x + attended

        # Residual branch 2: normalize, feed-forward, add back
        transformed = self.feed_forward(self.ln2(x))
        return x + transformed
Note: GPT uses pre-norm (LayerNorm before each sub-layer), while the original Transformer used post-norm (LayerNorm after the residual addition).
class GPT(nn.Module):
    """GPT-style language model.

    Composed of an embedding module, a stack of ``n_layers`` TransformerBlocks,
    a final LayerNorm, and a bias-free linear head that projects to the
    vocabulary.
    """

    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_seq_len, dropout):
        super().__init__()
        self.max_seq_len = max_seq_len

        # Input embeddings
        self.embeddings = Embeddings(vocab_size, d_model, max_seq_len, dropout)

        # Stack of identical transformer blocks
        layers = [TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)

        # LayerNorm applied after the last block
        self.ln_f = nn.LayerNorm(d_model)

        # Language model head (d_model -> vocab_size, no bias)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Run the custom initializer over every sub-module
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
def forward(self, x, targets=None):
    """Run the model and optionally compute the cross-entropy loss.

    Args:
        x: token ids of shape (batch_size, seq_len)
        targets: optional target ids; when given, the loss is computed
            against the flattened logits

    Returns:
        (logits, loss) where logits is (batch, seq_len, vocab_size) and
        loss is None when no targets are supplied.
    """
    # Causal mask sized to the current sequence length, on the input's device
    causal_mask = create_causal_mask(x.size(1), x.device)

    # Embed, then pass through every transformer block in order
    hidden = self.embeddings(x)  # (batch, seq_len, d_model)
    for block in self.blocks:
        hidden = block(hidden, causal_mask)

    # Final layer norm followed by projection to the vocabulary
    logits = self.lm_head(self.ln_f(hidden))  # (batch, seq_len, vocab_size)

    if targets is None:
        return logits, None

    # Flatten batch and sequence dimensions for cross-entropy
    vocab_size = logits.size(-1)
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
    return logits, loss
# Build the model from the config and move it to the training device
model_kwargs = {
    name: config[name]
    for name in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'max_seq_len', 'dropout')
}
model = GPT(**model_kwargs).to(device)

# AdamW optimizer with the betas / weight decay used for GPT here
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

# Training loop: one pass over the dataloader per epoch
model.train()
for epoch in range(config['num_epochs']):
    total_loss = 0
    for batch_idx, (x, y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass (the model returns the loss when targets are given)
        logits, loss = model(x, targets=y)

        # Backward pass, clip the gradient norm to 1.0, then update
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    # Mean loss over the batches of this epoch
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{config['num_epochs']}, Loss: {avg_loss:.4f}")
**Best practices for training GPT:**
1. **Gradient clipping**
- Prevents exploding gradients
@torch.no_grad()
def generate_greedy(model, tokenizer, prompt, max_new_tokens=50):
    """Generate text by always appending the highest-scoring next token.

    Args:
        model: callable returning (logits, loss); must expose ``max_seq_len``
        tokenizer: object with ``encode``, ``decode`` and ``eot_token``
        prompt: text to continue
        max_new_tokens: upper bound on the number of tokens appended

    Returns:
        The decoded prompt plus generated tokens (including the
        end-of-sequence token if one was produced).

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    model.eval()

    # Encode prompt
    tokens = tokenizer.encode(prompt)
    if len(tokens) == 0:
        raise ValueError("prompt must encode to at least one token")

    # Place the input on the model's own device instead of relying on a
    # module-level `device` variable that may be missing or out of sync.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    x = torch.tensor([tokens], dtype=torch.long, device=model_device)

    for _ in range(max_new_tokens):
        # Get predictions (crop to max_seq_len if needed)
        x_crop = x[:, -model.max_seq_len:]
        logits, _ = model(x_crop)

        # Focus on last token's predictions
        logits = logits[:, -1, :]  # (batch, vocab_size)

        # Get token with highest score
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

        # Append to sequence
        x = torch.cat([x, next_token], dim=1)

        # Stop if we generate end-of-sequence token
        if next_token.item() == tokenizer.eot_token:
            break

    # Decode and return
    generated_text = tokenizer.decode(x[0].tolist())
    return generated_text
30
# Example: greedily continue a short prompt for up to 100 new tokens
prompt = "Once upon a time"
generated = generate_greedy(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
)
print(generated)
**Different ways to sample next token:**
1. **Greedy Decoding**
- Always pick most likely token
def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None):
    """
    Sample next token from logits with various strategies.

    The caller's tensor is never modified. Filtering works along the last
    dimension, so both a single ``(vocab_size,)`` vector and a batched
    ``(batch, vocab_size)`` tensor are accepted.

    Args:
        logits: (vocab_size,) or (batch, vocab_size) unnormalized log probabilities
        temperature: Temperature for sampling; 0 selects the arg-max token
            deterministically instead of dividing by zero
        top_k: If set to a positive value, only sample from top k tokens
            (None or a non-positive value disables the filter)
        top_p: If set, only sample from nucleus (top-p)

    Returns:
        Tensor of sampled token ids with a trailing dimension of size 1.

    Raises:
        ValueError: if temperature is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Temperature 0 is the greedy limit; avoid the division by zero
    if temperature == 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Apply temperature (creates a new tensor)
    logits = logits / temperature

    # Top-k filtering
    if top_k is not None and top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float('-inf'))

    # Nucleus (top-p) filtering
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above threshold.
        # Shift right so the first token crossing the threshold is kept,
        # and always keep the most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order; scatter
        # along the last dim keeps this correct for batched input too.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

    # Sample from distribution
    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100,
             temperature=1.0, top_k=40, top_p=0.9):
    """Generate text by repeatedly sampling the next token.

    Args:
        model: callable returning (logits, loss); must expose ``max_seq_len``
        tokenizer: object with ``encode``, ``decode`` and ``eot_token``
        prompt: text to continue
        max_new_tokens: upper bound on the number of tokens appended
        temperature, top_k, top_p: forwarded to ``sample_next_token``

    Returns:
        The decoded prompt plus sampled tokens.
    """
    model.eval()

    tokens = tokenizer.encode(prompt)

    # Place the input on the model's own device instead of relying on a
    # module-level `device` variable that may be missing or out of sync.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    x = torch.tensor([tokens], dtype=torch.long, device=model_device)

    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length
        x_crop = x[:, -model.max_seq_len:]
        logits, _ = model(x_crop)
        logits = logits[:, -1, :]  # Last token

        # Sample next token from the single sequence in the batch
        next_token = sample_next_token(
            logits[0],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        # Restore the batch dimension before appending
        x = torch.cat([x, next_token.unsqueeze(0)], dim=1)

        # Stop at the end-of-sequence token
        if next_token.item() == tokenizer.eot_token:
            break

    return tokenizer.decode(x[0].tolist())

# Creative generation
text = generate(model, tokenizer, "The AI revolution",
                temperature=0.8, top_p=0.9)
print(text)
1 10M -> 100M -> 1B -> 10B -> 100B+
**Practical model sizes for different use cases:**
- **10M-100M**: Learning/experimentation, simple tasks
**Training a 125M parameter GPT:**
| Training Time | 1-2 days (single GPU) |
| --- | --- |
| Training Data | $\sim$10-100 GB text |
| Total Compute | $\sim$100 GPU-hours |
**Scaling up to GPT-3 (175B):**
- 1000x more parameters
- ~10,000x more compute needed
- Requires distributed training across many GPUs
- Estimated $4.6M in compute costs
This is why pre-trained models are so valuable—you don't have to train from scratch!
**Problems you might encounter:**
1. **Loss not decreasing**
- Check learning rate (try 1e-4 to 3e-4)
**Ways to enhance your GPT implementation:**
1. **Architectural improvements**
- Rotary Position Embeddings (RoPE)
**Recommended resources:**
- **Andrej Karpathy's nanoGPT**
  - Clean, minimal GPT implementation
- **Andrej Karpathy's "Let's build GPT" video**
  - Excellent step-by-step tutorial
- **HuggingFace Transformers**
  - Production-ready implementations
- **PyTorch Documentation**
  - Official tutorials and guides
<div class="callout warning">
Implement and train a small GPT model on your own text corpus!
**Steps:**
1. Choose a dataset (Shakespeare, Wikipedia, your own text)
2. Set up the data pipeline
3. Initialize the model (start small: 6 layers, 384 d_model)
4. Train for 10-20 epochs
5. Experiment with generation
6. Try different sampling strategies
Starter code available:
1. **GPT is conceptually simple**
- Stack of transformer decoder blocks
<div class="callout info">
<div class="callout info">
**Week 8: No Classes - Instructor Away**
**Use this week to:**
- Complete the GPT implementation exercise
- Catch up on readings
- Work on assignments
- Experiment with different architectures
Week 9: RAG & Mixture of Experts
Mixture of Experts architectures
Ethics, Bias, and Safety in LLMs
Questions?
Happy Coding!
TODO: Add manual table of contents or navigation