* 2. Include this script: * 3. Create charts with minimal configuration - colors are auto-applied! */ (function() { 'use strict'; // ========================================================================== // READ COLORS FROM CSS CUSTOM PROPERTIES // This ensures chart colors stay in sync with the theme // ========================================================================== /** * Get a CSS custom property value from :root */ function getCSSVar(name, fallback = '') { if (typeof getComputedStyle === 'undefined') return fallback; const value = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); return value || fallback; } /** * Build palette from CSS custom properties (with fallbacks) */ function buildPaletteFromCSS() { return { // Primary brand colors dartmouthGreen: getCSSVar('--dartmouth-green', '#00693e'), textPrimary: getCSSVar('--text-primary', '#0a2518'), textSecondary: getCSSVar('--text-secondary', '#0a3d23'), // Chart colors (from CSS --chart-color-N variables) chartColors: [ getCSSVar('--chart-color-1', '#00693e'), getCSSVar('--chart-color-2', '#267aba'), getCSSVar('--chart-color-3', '#ffa00f'), getCSSVar('--chart-color-4', '#9d162e'), getCSSVar('--chart-color-5', '#8a6996'), getCSSVar('--chart-color-6', '#a5d75f'), getCSSVar('--chart-color-7', '#003c73'), getCSSVar('--chart-color-8', '#d94415'), getCSSVar('--chart-color-9', '#643c20'), getCSSVar('--chart-color-10', '#c4dd88'), getCSSVar('--chart-color-11', '#f5dc69'), getCSSVar('--chart-color-12', '#424141'), ], // Background colors (semi-transparent versions) chartBgColors: [ getCSSVar('--chart-bg-1', 'rgba(0, 105, 62, 0.5)'), getCSSVar('--chart-bg-2', 'rgba(38, 122, 186, 0.5)'), getCSSVar('--chart-bg-3', 'rgba(255, 160, 15, 0.5)'), getCSSVar('--chart-bg-4', 'rgba(157, 22, 46, 0.5)'), getCSSVar('--chart-bg-5', 'rgba(138, 105, 150, 0.5)'), getCSSVar('--chart-bg-6', 'rgba(165, 215, 95, 0.5)'), ], // Semantic colors positive: getCSSVar('--chart-positive', '#00693e'), negative: 
getCSSVar('--chart-negative', '#9d162e'), neutral: getCSSVar('--chart-neutral', '#424141'), highlight: getCSSVar('--chart-highlight', '#ffa00f'), // Grid and axis colors gridLight: getCSSVar('--chart-grid-light', 'rgba(0, 105, 62, 0.1)'), gridMedium: getCSSVar('--chart-grid-medium', 'rgba(0, 105, 62, 0.15)'), gridDark: getCSSVar('--chart-grid-dark', 'rgba(0, 105, 62, 0.2)'), axisColor: getCSSVar('--chart-axis-color', '#0a2518'), // Font fontFamily: getCSSVar('--chart-font-family', "'Avenir LT Std', 'Avenir', 'Avenir Next', -apple-system, BlinkMacSystemFont, sans-serif"), }; } // Initialize palette (will be populated when DOM is ready) let CDL_PALETTE = null; // For convenience, expose primary chart colors array let CHART_COLORS = null; // ========================================================================== // FONT CONFIGURATION // Responsive font sizes based on typical Marp slide dimensions (1280x720) // ========================================================================== const FONT_CONFIG = { sizes: { title: 22, // Chart title subtitle: 18, // Subtitle legend: 16, // Legend labels axisTitle: 18, // Axis titles axisTicks: 16, // Axis tick labels tooltip: 14, // Tooltip text dataLabels: 14, // Data labels on charts }, weight: { normal: 400, medium: 500, bold: 600, }, }; // ========================================================================== // HELPER FUNCTIONS // ========================================================================== /** * Ensure palette is initialized */ function ensurePalette() { if (!CDL_PALETTE) { CDL_PALETTE = buildPaletteFromCSS(); CHART_COLORS = CDL_PALETTE.chartColors; } return CDL_PALETTE; } /** * Get color for a dataset at given index * Cycles through palette if more datasets than colors */ function getColor(index) { ensurePalette(); return CHART_COLORS[index % CHART_COLORS.length]; } /** * Get color with alpha transparency */ function getColorWithAlpha(color, alpha) { // Handle hex colors if (color.startsWith('#')) { 
const r = parseInt(color.slice(1, 3), 16); const g = parseInt(color.slice(3, 5), 16); const b = parseInt(color.slice(5, 7), 16); return `rgba(${r}, ${g}, ${b}, ${alpha})`; } // Handle rgba colors if (color.startsWith('rgba')) { return color.replace(/[\d.]+\)$/, `${alpha})`); } return color; } /** * Generate colors for all datasets in chart data * Automatically assigns colors if not specified */ function autoAssignColors(data, chartType) { if (!data || !data.datasets) return data; data.datasets.forEach((dataset, index) => { const baseColor = getColor(index); // Only assign colors if not already specified switch (chartType) { case 'bar': case 'horizontalBar': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'line': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.1); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 3; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 6; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } if (dataset.tension === undefined) { dataset.tension = 0.3; } break; case 'scatter': case 'bubble': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 15; } if (dataset.pointHoverRadius === undefined) { dataset.pointHoverRadius = 18; } break; case 'pie': case 'doughnut': case 'polarArea': // For pie charts, we need multiple colors for one dataset if (!dataset.backgroundColor) { const numItems = dataset.data ? 
dataset.data.length : 6; dataset.backgroundColor = []; for (let i = 0; i < numItems; i++) { dataset.backgroundColor.push(getColor(i)); } } if (!dataset.borderColor) { dataset.borderColor = '#d8d8d8'; // Slide background } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'radar': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.2); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 4; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } break; default: // Generic color assignment if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } } }); return data; } // ========================================================================== // CHART.JS GLOBAL DEFAULTS // ========================================================================== function applyGlobalDefaults() { if (typeof Chart === 'undefined') { console.warn('Chart.js not loaded. 
chart-defaults.js requires Chart.js to be loaded first.'); return false; } // Ensure palette is loaded from CSS const palette = ensurePalette(); // Font defaults Chart.defaults.font.family = palette.fontFamily; Chart.defaults.font.size = FONT_CONFIG.sizes.axisTicks; Chart.defaults.color = palette.textPrimary; // Responsive defaults Chart.defaults.responsive = true; Chart.defaults.maintainAspectRatio = false; // Animation (subtle) Chart.defaults.animation.duration = 400; // Plugin defaults // Legend Chart.defaults.plugins.legend.labels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, weight: FONT_CONFIG.weight.normal, }; Chart.defaults.plugins.legend.labels.color = palette.textPrimary; Chart.defaults.plugins.legend.labels.usePointStyle = true; Chart.defaults.plugins.legend.labels.padding = 20; // Title Chart.defaults.plugins.title.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.title, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.title.color = palette.textPrimary; // Tooltip Chart.defaults.plugins.tooltip.backgroundColor = palette.textPrimary; Chart.defaults.plugins.tooltip.titleFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.tooltip.bodyFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, }; Chart.defaults.plugins.tooltip.cornerRadius = 4; Chart.defaults.plugins.tooltip.padding = 10; // Scale defaults (for cartesian charts) // These need to be applied per-scale type const scaleDefaults = { grid: { color: palette.gridLight, lineWidth: 1, }, border: { color: palette.gridDark, width: 1, }, ticks: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }, color: palette.textPrimary, }, title: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTitle, weight: FONT_CONFIG.weight.normal, }, color: palette.textPrimary, }, }; // Apply scale defaults to linear scale if (Chart.defaults.scales && 
Chart.defaults.scales.linear) { if (Chart.defaults.scales.linear.grid) Object.assign(Chart.defaults.scales.linear.grid, scaleDefaults.grid); if (Chart.defaults.scales.linear.border) Object.assign(Chart.defaults.scales.linear.border, scaleDefaults.border); if (Chart.defaults.scales.linear.ticks) Object.assign(Chart.defaults.scales.linear.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.linear.title) Object.assign(Chart.defaults.scales.linear.title, scaleDefaults.title); } // Apply scale defaults to category scale if (Chart.defaults.scales && Chart.defaults.scales.category) { if (Chart.defaults.scales.category.grid) Object.assign(Chart.defaults.scales.category.grid, scaleDefaults.grid); if (Chart.defaults.scales.category.border) Object.assign(Chart.defaults.scales.category.border, scaleDefaults.border); if (Chart.defaults.scales.category.ticks) Object.assign(Chart.defaults.scales.category.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.category.title) Object.assign(Chart.defaults.scales.category.title, scaleDefaults.title); } // Apply scale defaults to logarithmic scale if (Chart.defaults.scales && Chart.defaults.scales.logarithmic) { if (Chart.defaults.scales.logarithmic.grid) Object.assign(Chart.defaults.scales.logarithmic.grid, scaleDefaults.grid); if (Chart.defaults.scales.logarithmic.border) Object.assign(Chart.defaults.scales.logarithmic.border, scaleDefaults.border); if (Chart.defaults.scales.logarithmic.ticks) Object.assign(Chart.defaults.scales.logarithmic.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.logarithmic.title) Object.assign(Chart.defaults.scales.logarithmic.title, scaleDefaults.title); } // Apply scale defaults to radial scale (for radar charts) if (Chart.defaults.scales && Chart.defaults.scales.radialLinear) { if (Chart.defaults.scales.radialLinear.grid) Chart.defaults.scales.radialLinear.grid.color = palette.gridLight; if (Chart.defaults.scales.radialLinear.angleLines) Chart.defaults.scales.radialLinear.angleLines.color = 
palette.gridMedium; if (Chart.defaults.scales.radialLinear.pointLabels) { Chart.defaults.scales.radialLinear.pointLabels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }; Chart.defaults.scales.radialLinear.pointLabels.color = palette.textPrimary; } } return true; } // ========================================================================== // CHART WRAPPER FOR AUTO-STYLING // ========================================================================== /** * Wrap the Chart constructor to automatically apply CDL styling */ function wrapChartConstructor() { if (typeof Chart === 'undefined') return; const OriginalChart = Chart; // Create a wrapper that auto-applies colors window.Chart = function(ctx, config) { // Auto-assign colors if not specified if (config && config.data) { config.data = autoAssignColors(config.data, config.type); } // Merge default options for specific chart types if (config && config.options) { config.options = applyChartTypeDefaults(config.type, config.options); } // Call original constructor return new OriginalChart(ctx, config); }; // Copy static properties and methods Object.setPrototypeOf(window.Chart, OriginalChart); Object.assign(window.Chart, OriginalChart); // Preserve the prototype chain window.Chart.prototype = OriginalChart.prototype; } /** * Apply chart-type specific defaults */ function applyChartTypeDefaults(chartType, userOptions) { const options = { ...userOptions }; switch (chartType) { case 'bar': case 'horizontalBar': // Bar chart defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; // Hide x-axis grid for cleaner look if (options.scales.x.grid === undefined) { options.scales.x.grid = { display: false }; } break; case 'line': // Line chart defaults if (!options.interaction) { options.interaction = { intersect: false, mode: 'index' }; } break; case 'pie': case 'doughnut': // Pie/doughnut defaults if 
(!options.plugins) options.plugins = {}; if (options.plugins.legend === undefined) { const palette = ensurePalette(); options.plugins.legend = { position: 'right', labels: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, }, color: palette.textPrimary, padding: 15, }, }; } break; case 'radar': // Radar chart defaults - keep as-is, scale defaults applied globally break; case 'scatter': case 'bubble': // Scatter/bubble defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; break; } return options; } // ========================================================================== // CONVENIENCE FUNCTIONS FOR USERS // Exposed on window.CDLChart for easy access // ========================================================================== window.CDLChart = { // Color palette access (getters to ensure lazy initialization) get colors() { return ensurePalette().chartColors; }, get palette() { return ensurePalette(); }, // Get specific color by index getColor: getColor, // Get color with transparency getColorWithAlpha: getColorWithAlpha, // Get array of colors for a specific count getColors: function(count) { ensurePalette(); const result = []; for (let i = 0; i < count; i++) { result.push(getColor(i)); } return result; }, // Font configuration fonts: FONT_CONFIG, // Quick chart creation helpers // These create minimal config that auto-applies all styling /** * Create a simple bar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - X-axis labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ bar: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'bar', data: { labels: labels, datasets: [{ data: data }], }, options: { plugins: { legend: { display: false } }, ...options, }, }); }, /** * Create a simple line chart * @param {string} canvasId - 
Canvas element ID * @param {string[]} labels - X-axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ line: function(canvasId, labels, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'line', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, fill: ds.fill !== undefined ? ds.fill : true, })), }, options: options, }); }, /** * Create a simple pie chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ pie: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'pie', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a simple scatter chart * @param {string} canvasId - Canvas element ID * @param {Array} datasets - Array of {label, data: [{x, y}]} objects * @param {object} options - Optional overrides */ scatter: function(canvasId, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'scatter', data: { datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, /** * Create a doughnut chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ doughnut: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'doughnut', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a radar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ radar: function(canvasId, labels, datasets, 
options = {}) { return new Chart(document.getElementById(canvasId), { type: 'radar', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, }; // ========================================================================== // INITIALIZATION // ========================================================================== function initialize() { // Wait for Chart.js to be available if (typeof Chart !== 'undefined') { applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully.'); return true; } else { // Chart.js not yet loaded - wait and retry let retries = 0; const maxRetries = 50; // 5 seconds max wait const checkInterval = setInterval(function() { retries++; if (typeof Chart !== 'undefined') { clearInterval(checkInterval); applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully (after waiting for Chart.js).'); } else if (retries >= maxRetries) { clearInterval(checkInterval); console.warn('Chart.js not found after waiting. CDL Chart defaults not applied.'); } }, 100); return false; } } // Initialize IMMEDIATELY - this must run BEFORE any chart creation scripts // Chart.js CDN should be loaded before this script initialize(); })();
Step-by-step routing for a single token:
import torch
import torch.nn.functional as F


def route_token(x, router_weights, num_experts=8, k=2):
    """Route a token to its top-k experts.

    Args:
        x: Token hidden state, shape (hidden_dim,). A batch of tokens of shape
            (..., hidden_dim) also works; routing is done per token.
        router_weights: Router projection, shape (hidden_dim, num_experts).
        num_experts: Number of experts. Kept for readability of the example;
            the actual count comes from ``router_weights``.
        k: Number of experts to select per token.

    Returns:
        (top_k_indices, top_k_probs): the selected expert indices and their gate
        weights, renormalized so each token's k weights sum to 1.
    """
    # Step 1: Compute routing scores (linear projection)
    logits = x @ router_weights  # Shape: (num_experts,)
    # Example: [-0.5, 2.1, 1.3, 0.2, -0.1, 0.0, -0.3, 0.1]

    # Step 2: Convert to probabilities
    probs = F.softmax(logits, dim=-1)
    # Example: [0.03, 0.47, 0.21, 0.07, 0.05, 0.06, 0.04, 0.06]

    # Step 3: Select top-k experts
    top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
    # indices: [1, 2], probs: [0.47, 0.21]

    # Step 4: Normalize selected probabilities (per token)
    top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
    # [0.69, 0.31]  # Now sums to 1

    return top_k_indices, top_k_probs
# Expert 1 handles ~69%, Expert 2 handles ~31%
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    """Sparse Mixture-of-Experts feed-forward layer with top-k routing.

    Simplified for teaching: ``expert_capacity`` is accepted but not enforced,
    and tokens are dispatched with Python loops rather than batched kernels.
    """

    def __init__(self, d_model, num_experts, expert_capacity, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k  # Top-k routing
        # Stored for reference only; forward() does not apply capacity limits.
        self.expert_capacity = expert_capacity

        # Router
        self.gate = nn.Linear(d_model, num_experts)

        # Experts (simple FFN for each)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.ReLU(),
                nn.Linear(4 * d_model, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        """Mix the top-k experts' outputs for every token.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model); any leading
                dims before d_model are supported.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Compute routing scores
        router_logits = self.gate(x)  # (batch, seq, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(router_probs, self.k, dim=-1)
        # top_k_probs, top_k_indices: (batch, seq, k)

        # Renormalize so each token's k gate weights sum to 1 (same as Step 4 of
        # route_token); otherwise the output is scaled down by the discarded mass.
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Initialize output
        output = torch.zeros_like(x)

        # Route to experts
        for i in range(self.k):
            # i-th choice of expert, and its weight, for every token
            expert_idx = top_k_indices[..., i]  # (batch, seq)
            expert_weight = top_k_probs[..., i]  # (batch, seq)

            # Process through each expert
            for expert_id in range(self.num_experts):
                # Mask for tokens routed to this expert
                mask = expert_idx == expert_id
                if mask.any():
                    expert_output = self.experts[expert_id](x[mask])
                    # Add weighted output
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output

        return output
Note: This is simplified. Production implementations handle batching and load balancing more efficiently.
Problem: Without balancing, some experts get overused!
1Expert Usage Distribution (Unbalanced):
2
3E1: ████████████████████████████████ 40% ← Overloaded!
4E2: ██████████████████ 22%
5E3: ████████████ 15%
6E4: ████████ 10%
7E5: ████ 5%
8E6: ███ 4%
9E7: ██ 3% ← Undertrained
10E8: █ 1% ← "Dead" expert
Desired (Balanced):
1E1: ████████████ 12.5%
2E2: ████████████ 12.5%
3E3: ████████████ 12.5%
4... (all equal)
Why imbalance is bad:
Solution 1: Auxiliary Loss (Penalize Imbalance)
def compute_load_balance_loss(router_probs, expert_assignments, alpha=0.01):
    """Add penalty to main loss for imbalanced routing (Switch-Transformer style).

    Args:
        router_probs: Router softmax outputs, shape (num_tokens, num_experts) or
            (batch, seq, num_experts); leading dims are flattened.
        expert_assignments: One-hot (or k-hot) routing decisions with the same
            shape as ``router_probs``: 1 where a token was sent to an expert.
        alpha: Weight of the auxiliary loss.

    Returns:
        Scalar tensor ``alpha * N * sum_i f_i * P_i``, to be added to the main
        training loss. It equals ``alpha`` under perfectly uniform routing.
    """
    num_experts = router_probs.shape[-1]
    router_probs = router_probs.reshape(-1, num_experts)
    expert_assignments = expert_assignments.reshape(-1, num_experts).to(router_probs.dtype)

    # f_i: fraction of tokens actually sent to expert i
    tokens_per_expert = expert_assignments.sum(dim=0)  # Count per expert
    f = tokens_per_expert / tokens_per_expert.sum()  # e.g. [0.4, 0.22, ...]

    # P_i: average routing probability for expert i (this term carries the gradient)
    P = router_probs.mean(dim=0)  # e.g. [0.35, 0.2, ...]

    # Loss: encourages f and P to both be uniform (1/N each)
    aux_loss = alpha * num_experts * (f * P).sum()
    return aux_loss  # Added to main training loss
Solution 2: Expert Capacity Limits
# Expert capacity = tokens per batch / number of experts, scaled by capacity_factor (e.g. 1.25)
capacity = int(capacity_factor * batch_size * seq_len / num_experts)
# Tokens routed to an expert that already holds `capacity` tokens overflow: depending on
# the implementation they are dropped (passed through unchanged by the residual
# connection) or re-routed to their next-choice expert (e.g. Expert 2).
Challenges in training MoE:
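Common practice: Start from a trained dense model and convert it to an MoE by copying its feed-forward weights into every expert ("sparse upcycling"), then continue training or fine-tune.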
MoE Memory Requirements:
All experts must be in memory, even if only 2 are active!
Example: Mixtral 8x7B
Solutions:
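A state-of-the-art sparse MoE model from Mistral AI with 8 experts per layer. "8x7B" refers to its Mistral-7B-style backbone, not 8 separate 7B models. Only the feed-forward blocks are replicated per expert, while attention and embeddings are shared. This gives ~47B total parameters (not 56B), of which ~13B are active per token.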
Architecture:
Training:
Reference: Jiang et al. (2024) - "Mixtral of Experts"
Comparison with dense models:
| Model | Total Params | Active Params | MMLU Score | Speed vs 70B |
|---|---|---|---|---|
| Llama 2 13B | 13B | 13B | 55.0 | 5× faster |
| Mixtral 8x7B | 47B | 13B | 70.6 | 5× faster |
| Llama 2 70B | 70B | 70B | 69.7 | 1× (baseline) |
| GPT-3.5 | Undisclosed (often estimated ~175B) | Undisclosed | 70.0 | N/A (API) |
The Magic of MoE:
1Mixtral achieves:
2┌─────────────────────────────────────────────┐
3│ Quality of 70B model (MMLU: 70.6 vs 69.7) │
4│ Speed of 13B model (only 13B active) │
5│ = Best of both worlds! │
6└─────────────────────────────────────────────┘
Still need memory for all 47B params. Savings are in compute, not VRAM.
Experts can specialize without explicit supervision, though not necessarily by topic!
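Illustrative routing patterns (hypothetical expert numbers). The Mixtral paper's own routing analysis found no obvious topic-based expert assignment. It did find syntactic patterns, e.g. Python `self` or indentation tokens consistently going to the same experts: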
1Token Type → Most Active Experts
2─────────────────────────────────────────────
3"def", "class", "import" → Expert 2 (Code)
4"la", "le", "français" → Expert 5 (French)
5"∑", "∫", "theorem" → Expert 7 (Math)
6"the", "is", "and" → Expert 1 (Common words)
7"neural", "gradient" → Expert 3 (Technical)
Illustrative example: sentence routing (hypothetical expert assignments)
1"The neural network learns via backpropagation"
2 │ │ │ │ │ │
3 E1 E3 E3 E1 E1 E3
4
5Different tokens in the same sentence use different experts!
Any specialization is emergent: nobody tells a particular expert to handle code!
Beyond MoE: Making models smaller and faster
1. Quantization
# Original: 32-bit float (4 bytes/param)
weight = 0.123456789  # Full precision

# INT8: 8-bit integer (1 byte/param)
# Absmax scaling: assume the largest |weight| in this tensor is 0.5, mapped to 127
scale = 0.5 / 127
weight_int8 = round(weight / scale)  # Scaled + quantized -> 31
# 4× memory reduction!
# Dequantized: weight_int8 * scale ≈ 0.1220 (small rounding error)

# INT4: 4-bit (0.5 byte/param)
# 8× memory reduction!
Llama 2 7B Memory:
2. Knowledge Distillation
# Teacher: Large model (GPT-4)
# Student: Small model (GPT-2)

teacher_output = gpt4(input)     # teacher logits
student_output = gpt2(input)     # student logits for the same input

# Train student to match teacher's output distribution.
# In PyTorch this is typically written with temperature T as
#   F.kl_div(F.log_softmax(student_output / T, dim=-1),
#            F.softmax(teacher_output / T, dim=-1),
#            reduction="batchmean") * T**2
# i.e. the student's LOG-probabilities come first, the teacher's probabilities second.
loss = KL_div(student_output,
              teacher_output)
Results:
Making generation faster:
1. KV Cache (Essential)
# Without cache: Recompute all attention
# Token 100 attends to tokens 1-99, and K,V for all of them are recomputed
# = O(n²) attention per token!

# With cache: Store previous K,V
cache = {}
for pos, token in enumerate(sequence):
    k, v = compute_kv(token)
    cache[pos] = (k, v)  # Store!
    # Each token's K,V is computed only once; the new token's query then
    # attends over the cached K,V, i.e. O(n) work per generated token
2. Flash Attention
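Standard: Materializes the full n×n attention matrix in GPU memory (HBM)
FlashAtt: Tiled and memory-efficient; computes exact attention block-by-block in fast on-chip SRAM
→ 2-4× faster, fits longer sequences (memory grows linearly, not quadratically)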
3. Speculative Decoding
# Draft model (fast, small): 7B
draft_tokens = small_model.generate(5)
# ["The", "cat", "sat", "on", "mat"]

# Target model (slow, large): 70B, which checks all 5 draft tokens in ONE forward pass
verified = large_model.verify(draft_tokens)
# ["The", "cat", "sat", "on", "the"]
#   ✓      ✓      ✓     ✓     ✗

# Accept 4/5 in one batch, and take the target's own token ("the") at the mismatch,
# so one pass of the large model yields 5 tokens!
# 2-3× speedup, same output distribution as the large model
4. Continuous Batching (vLLM)
Alternative to Transformers with linear-time complexity
Transformer Attention: O(n²)
1Sequence length: 1K 4K 16K 64K
2Compute (relative): 1 16 256 4096
3 ↑
4 Gets expensive fast!
Mamba SSM: O(n)
1Sequence length: 1K 4K 16K 64K
2Compute (relative): 1 4 16 64
3 ↑
4 Linear scaling!
How SSMs Work (Simplified)
# State evolves with each token
state = initial_state  # Fixed size!
for token in sequence:
    # Update state (no attention over previous tokens)
    state = A @ state + B @ token
    output = C @ state
    # Always O(1) per token, independent of sequence length!
# In Mamba, B, C (and the discretization step) are computed from the current
# token, which is what makes the state space "selective".
Key Insight:
Reference: Gu & Dao (2023) - "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
Choosing the Right Technique for Your Use Case:
                          Quality
                             ↑
GPT-4 (size undisclosed)     │ ● Dense Large
                             │
Mixtral (47B/13B)            │ ● MoE
                             │
Llama-7B-Q4                  │ ● Quantized
                             │
DistilGPT-2                  │ ● Distilled
                             │
                             └──────────────────────→ Speed/Cost
| Technique | Best For | Trade-off |
|---|---|---|
| Dense Large | Maximum quality | Expensive, slow |
| MoE | Quality + speed | High memory |
| Quantization | Edge deployment | Slight quality loss |
| Distillation | Fixed tasks | Requires teacher |
| Pruning | Latency-critical | Irreversible |
Need max quality? → Dense. Need speed on GPU? → MoE. Need to run locally? → Quantized.
The carbon cost of large language models:
| Model | Training emissions (tonnes CO₂e) | Equivalent |
|---|---|---|
| GPT-3 | 502 | 112 cars for 1 year |
| Llama 2 (all model sizes combined) | ~539 | 120 cars for 1 year |
Factors affecting carbon footprint:
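Sources: GPT-3 estimate from Luccioni et al. (2022) - "Estimating the Carbon Footprint of BLOOM"; Llama 2 figure from Touvron et al. (2023) - "Llama 2: Open Foundation and Fine-Tuned Chat Models". Background: Strubell et al. (2019) - "Energy and Policy Considerations for Deep Learning in NLP"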
Should powerful AI be accessible to everyone, or only large organizations?
Current reality:
Efficiency enables democratization:
Make AI accessible to researchers, startups, and developing nations—not just big tech!
Closed (GPT-4, Claude):
Open (Llama, Mixtral):
The debate:
Where is the field heading?
Required:
Recommended:
Questions?
Next: Ethics, Bias, and Safety!