* 2. Include this script: * 3. Create charts with minimal configuration - colors are auto-applied! */ (function() { 'use strict'; // ========================================================================== // READ COLORS FROM CSS CUSTOM PROPERTIES // This ensures chart colors stay in sync with the theme // ========================================================================== /** * Get a CSS custom property value from :root */ function getCSSVar(name, fallback = '') { if (typeof getComputedStyle === 'undefined') return fallback; const value = getComputedStyle(document.documentElement).getPropertyValue(name).trim(); return value || fallback; } /** * Build palette from CSS custom properties (with fallbacks) */ function buildPaletteFromCSS() { return { // Primary brand colors dartmouthGreen: getCSSVar('--dartmouth-green', '#00693e'), textPrimary: getCSSVar('--text-primary', '#0a2518'), textSecondary: getCSSVar('--text-secondary', '#0a3d23'), // Chart colors (from CSS --chart-color-N variables) chartColors: [ getCSSVar('--chart-color-1', '#00693e'), getCSSVar('--chart-color-2', '#267aba'), getCSSVar('--chart-color-3', '#ffa00f'), getCSSVar('--chart-color-4', '#9d162e'), getCSSVar('--chart-color-5', '#8a6996'), getCSSVar('--chart-color-6', '#a5d75f'), getCSSVar('--chart-color-7', '#003c73'), getCSSVar('--chart-color-8', '#d94415'), getCSSVar('--chart-color-9', '#643c20'), getCSSVar('--chart-color-10', '#c4dd88'), getCSSVar('--chart-color-11', '#f5dc69'), getCSSVar('--chart-color-12', '#424141'), ], // Background colors (semi-transparent versions) chartBgColors: [ getCSSVar('--chart-bg-1', 'rgba(0, 105, 62, 0.5)'), getCSSVar('--chart-bg-2', 'rgba(38, 122, 186, 0.5)'), getCSSVar('--chart-bg-3', 'rgba(255, 160, 15, 0.5)'), getCSSVar('--chart-bg-4', 'rgba(157, 22, 46, 0.5)'), getCSSVar('--chart-bg-5', 'rgba(138, 105, 150, 0.5)'), getCSSVar('--chart-bg-6', 'rgba(165, 215, 95, 0.5)'), ], // Semantic colors positive: getCSSVar('--chart-positive', '#00693e'), negative: 
getCSSVar('--chart-negative', '#9d162e'), neutral: getCSSVar('--chart-neutral', '#424141'), highlight: getCSSVar('--chart-highlight', '#ffa00f'), // Grid and axis colors gridLight: getCSSVar('--chart-grid-light', 'rgba(0, 105, 62, 0.1)'), gridMedium: getCSSVar('--chart-grid-medium', 'rgba(0, 105, 62, 0.15)'), gridDark: getCSSVar('--chart-grid-dark', 'rgba(0, 105, 62, 0.2)'), axisColor: getCSSVar('--chart-axis-color', '#0a2518'), // Font fontFamily: getCSSVar('--chart-font-family', "'Avenir LT Std', 'Avenir', 'Avenir Next', -apple-system, BlinkMacSystemFont, sans-serif"), }; } // Initialize palette (will be populated when DOM is ready) let CDL_PALETTE = null; // For convenience, expose primary chart colors array let CHART_COLORS = null; // ========================================================================== // FONT CONFIGURATION // Responsive font sizes based on typical Marp slide dimensions (1280x720) // ========================================================================== const FONT_CONFIG = { sizes: { title: 22, // Chart title subtitle: 18, // Subtitle legend: 16, // Legend labels axisTitle: 18, // Axis titles axisTicks: 16, // Axis tick labels tooltip: 14, // Tooltip text dataLabels: 14, // Data labels on charts }, weight: { normal: 400, medium: 500, bold: 600, }, }; // ========================================================================== // HELPER FUNCTIONS // ========================================================================== /** * Ensure palette is initialized */ function ensurePalette() { if (!CDL_PALETTE) { CDL_PALETTE = buildPaletteFromCSS(); CHART_COLORS = CDL_PALETTE.chartColors; } return CDL_PALETTE; } /** * Get color for a dataset at given index * Cycles through palette if more datasets than colors */ function getColor(index) { ensurePalette(); return CHART_COLORS[index % CHART_COLORS.length]; } /** * Get color with alpha transparency */ function getColorWithAlpha(color, alpha) { // Handle hex colors if (color.startsWith('#')) { 
const r = parseInt(color.slice(1, 3), 16); const g = parseInt(color.slice(3, 5), 16); const b = parseInt(color.slice(5, 7), 16); return `rgba(${r}, ${g}, ${b}, ${alpha})`; } // Handle rgba colors if (color.startsWith('rgba')) { return color.replace(/[\d.]+\)$/, `${alpha})`); } return color; } /** * Generate colors for all datasets in chart data * Automatically assigns colors if not specified */ function autoAssignColors(data, chartType) { if (!data || !data.datasets) return data; data.datasets.forEach((dataset, index) => { const baseColor = getColor(index); // Only assign colors if not already specified switch (chartType) { case 'bar': case 'horizontalBar': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'line': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.1); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 3; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 6; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } if (dataset.tension === undefined) { dataset.tension = 0.3; } break; case 'scatter': case 'bubble': if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 15; } if (dataset.pointHoverRadius === undefined) { dataset.pointHoverRadius = 18; } break; case 'pie': case 'doughnut': case 'polarArea': // For pie charts, we need multiple colors for one dataset if (!dataset.backgroundColor) { const numItems = dataset.data ? 
dataset.data.length : 6; dataset.backgroundColor = []; for (let i = 0; i < numItems; i++) { dataset.backgroundColor.push(getColor(i)); } } if (!dataset.borderColor) { dataset.borderColor = '#d8d8d8'; // Slide background } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } break; case 'radar': if (!dataset.borderColor) { dataset.borderColor = baseColor; } if (!dataset.backgroundColor) { dataset.backgroundColor = getColorWithAlpha(baseColor, 0.2); } if (dataset.borderWidth === undefined) { dataset.borderWidth = 2; } if (dataset.pointRadius === undefined) { dataset.pointRadius = 4; } if (!dataset.pointBackgroundColor) { dataset.pointBackgroundColor = baseColor; } break; default: // Generic color assignment if (!dataset.backgroundColor) { dataset.backgroundColor = baseColor; } if (!dataset.borderColor) { dataset.borderColor = baseColor; } } }); return data; } // ========================================================================== // CHART.JS GLOBAL DEFAULTS // ========================================================================== function applyGlobalDefaults() { if (typeof Chart === 'undefined') { console.warn('Chart.js not loaded. 
chart-defaults.js requires Chart.js to be loaded first.'); return false; } // Ensure palette is loaded from CSS const palette = ensurePalette(); // Font defaults Chart.defaults.font.family = palette.fontFamily; Chart.defaults.font.size = FONT_CONFIG.sizes.axisTicks; Chart.defaults.color = palette.textPrimary; // Responsive defaults Chart.defaults.responsive = true; Chart.defaults.maintainAspectRatio = false; // Animation (subtle) Chart.defaults.animation.duration = 400; // Plugin defaults // Legend Chart.defaults.plugins.legend.labels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, weight: FONT_CONFIG.weight.normal, }; Chart.defaults.plugins.legend.labels.color = palette.textPrimary; Chart.defaults.plugins.legend.labels.usePointStyle = true; Chart.defaults.plugins.legend.labels.padding = 20; // Title Chart.defaults.plugins.title.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.title, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.title.color = palette.textPrimary; // Tooltip Chart.defaults.plugins.tooltip.backgroundColor = palette.textPrimary; Chart.defaults.plugins.tooltip.titleFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, weight: FONT_CONFIG.weight.medium, }; Chart.defaults.plugins.tooltip.bodyFont = { family: palette.fontFamily, size: FONT_CONFIG.sizes.tooltip, }; Chart.defaults.plugins.tooltip.cornerRadius = 4; Chart.defaults.plugins.tooltip.padding = 10; // Scale defaults (for cartesian charts) // These need to be applied per-scale type const scaleDefaults = { grid: { color: palette.gridLight, lineWidth: 1, }, border: { color: palette.gridDark, width: 1, }, ticks: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }, color: palette.textPrimary, }, title: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTitle, weight: FONT_CONFIG.weight.normal, }, color: palette.textPrimary, }, }; // Apply scale defaults to linear scale if (Chart.defaults.scales && 
Chart.defaults.scales.linear) { if (Chart.defaults.scales.linear.grid) Object.assign(Chart.defaults.scales.linear.grid, scaleDefaults.grid); if (Chart.defaults.scales.linear.border) Object.assign(Chart.defaults.scales.linear.border, scaleDefaults.border); if (Chart.defaults.scales.linear.ticks) Object.assign(Chart.defaults.scales.linear.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.linear.title) Object.assign(Chart.defaults.scales.linear.title, scaleDefaults.title); } // Apply scale defaults to category scale if (Chart.defaults.scales && Chart.defaults.scales.category) { if (Chart.defaults.scales.category.grid) Object.assign(Chart.defaults.scales.category.grid, scaleDefaults.grid); if (Chart.defaults.scales.category.border) Object.assign(Chart.defaults.scales.category.border, scaleDefaults.border); if (Chart.defaults.scales.category.ticks) Object.assign(Chart.defaults.scales.category.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.category.title) Object.assign(Chart.defaults.scales.category.title, scaleDefaults.title); } // Apply scale defaults to logarithmic scale if (Chart.defaults.scales && Chart.defaults.scales.logarithmic) { if (Chart.defaults.scales.logarithmic.grid) Object.assign(Chart.defaults.scales.logarithmic.grid, scaleDefaults.grid); if (Chart.defaults.scales.logarithmic.border) Object.assign(Chart.defaults.scales.logarithmic.border, scaleDefaults.border); if (Chart.defaults.scales.logarithmic.ticks) Object.assign(Chart.defaults.scales.logarithmic.ticks, scaleDefaults.ticks); if (Chart.defaults.scales.logarithmic.title) Object.assign(Chart.defaults.scales.logarithmic.title, scaleDefaults.title); } // Apply scale defaults to radial scale (for radar charts) if (Chart.defaults.scales && Chart.defaults.scales.radialLinear) { if (Chart.defaults.scales.radialLinear.grid) Chart.defaults.scales.radialLinear.grid.color = palette.gridLight; if (Chart.defaults.scales.radialLinear.angleLines) Chart.defaults.scales.radialLinear.angleLines.color = 
palette.gridMedium; if (Chart.defaults.scales.radialLinear.pointLabels) { Chart.defaults.scales.radialLinear.pointLabels.font = { family: palette.fontFamily, size: FONT_CONFIG.sizes.axisTicks, }; Chart.defaults.scales.radialLinear.pointLabels.color = palette.textPrimary; } } return true; } // ========================================================================== // CHART WRAPPER FOR AUTO-STYLING // ========================================================================== /** * Wrap the Chart constructor to automatically apply CDL styling */ function wrapChartConstructor() { if (typeof Chart === 'undefined') return; const OriginalChart = Chart; // Create a wrapper that auto-applies colors window.Chart = function(ctx, config) { // Auto-assign colors if not specified if (config && config.data) { config.data = autoAssignColors(config.data, config.type); } // Merge default options for specific chart types if (config && config.options) { config.options = applyChartTypeDefaults(config.type, config.options); } // Call original constructor return new OriginalChart(ctx, config); }; // Copy static properties and methods Object.setPrototypeOf(window.Chart, OriginalChart); Object.assign(window.Chart, OriginalChart); // Preserve the prototype chain window.Chart.prototype = OriginalChart.prototype; } /** * Apply chart-type specific defaults */ function applyChartTypeDefaults(chartType, userOptions) { const options = { ...userOptions }; switch (chartType) { case 'bar': case 'horizontalBar': // Bar chart defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; // Hide x-axis grid for cleaner look if (options.scales.x.grid === undefined) { options.scales.x.grid = { display: false }; } break; case 'line': // Line chart defaults if (!options.interaction) { options.interaction = { intersect: false, mode: 'index' }; } break; case 'pie': case 'doughnut': // Pie/doughnut defaults if 
(!options.plugins) options.plugins = {}; if (options.plugins.legend === undefined) { const palette = ensurePalette(); options.plugins.legend = { position: 'right', labels: { font: { family: palette.fontFamily, size: FONT_CONFIG.sizes.legend, }, color: palette.textPrimary, padding: 15, }, }; } break; case 'radar': // Radar chart defaults - keep as-is, scale defaults applied globally break; case 'scatter': case 'bubble': // Scatter/bubble defaults if (!options.scales) options.scales = {}; if (!options.scales.x) options.scales.x = {}; if (!options.scales.y) options.scales.y = {}; break; } return options; } // ========================================================================== // CONVENIENCE FUNCTIONS FOR USERS // Exposed on window.CDLChart for easy access // ========================================================================== window.CDLChart = { // Color palette access (getters to ensure lazy initialization) get colors() { return ensurePalette().chartColors; }, get palette() { return ensurePalette(); }, // Get specific color by index getColor: getColor, // Get color with transparency getColorWithAlpha: getColorWithAlpha, // Get array of colors for a specific count getColors: function(count) { ensurePalette(); const result = []; for (let i = 0; i < count; i++) { result.push(getColor(i)); } return result; }, // Font configuration fonts: FONT_CONFIG, // Quick chart creation helpers // These create minimal config that auto-applies all styling /** * Create a simple bar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - X-axis labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ bar: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'bar', data: { labels: labels, datasets: [{ data: data }], }, options: { plugins: { legend: { display: false } }, ...options, }, }); }, /** * Create a simple line chart * @param {string} canvasId - 
Canvas element ID * @param {string[]} labels - X-axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ line: function(canvasId, labels, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'line', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, fill: ds.fill !== undefined ? ds.fill : true, })), }, options: options, }); }, /** * Create a simple pie chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ pie: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'pie', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a simple scatter chart * @param {string} canvasId - Canvas element ID * @param {Array} datasets - Array of {label, data: [{x, y}]} objects * @param {object} options - Optional overrides */ scatter: function(canvasId, datasets, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'scatter', data: { datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, /** * Create a doughnut chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Slice labels * @param {number[]} data - Data values * @param {object} options - Optional overrides */ doughnut: function(canvasId, labels, data, options = {}) { return new Chart(document.getElementById(canvasId), { type: 'doughnut', data: { labels: labels, datasets: [{ data: data }], }, options: options, }); }, /** * Create a radar chart * @param {string} canvasId - Canvas element ID * @param {string[]} labels - Axis labels * @param {Array} datasets - Array of {label, data} objects * @param {object} options - Optional overrides */ radar: function(canvasId, labels, datasets, 
options = {}) { return new Chart(document.getElementById(canvasId), { type: 'radar', data: { labels: labels, datasets: datasets.map(ds => ({ label: ds.label, data: ds.data, })), }, options: options, }); }, }; // ========================================================================== // INITIALIZATION // ========================================================================== function initialize() { // Wait for Chart.js to be available if (typeof Chart !== 'undefined') { applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully.'); return true; } else { // Chart.js not yet loaded - wait and retry let retries = 0; const maxRetries = 50; // 5 seconds max wait const checkInterval = setInterval(function() { retries++; if (typeof Chart !== 'undefined') { clearInterval(checkInterval); applyGlobalDefaults(); wrapChartConstructor(); console.log('CDL Chart defaults applied successfully (after waiting for Chart.js).'); } else if (retries >= maxRetries) { clearInterval(checkInterval); console.warn('Chart.js not found after waiting. CDL Chart defaults not applied.'); } }, 100); return false; } } // Initialize IMMEDIATELY - this must run BEFORE any chart creation scripts // Chart.js CDN should be loaded before this script initialize(); })();

Lecture 7: Text Classification Workshop

PSYC 51.07: Models of language and communication

Jeremy R. Manning
Dartmouth College
Winter 2026

Learning Objectives

  1. Build text classifiers from scratch using scikit-learn
  2. Understand different text representation methods (BoW, TF-IDF, embeddings)
  3. Compare Naive Bayes, Logistic Regression, and Neural approaches
  4. Evaluate classifier performance using appropriate metrics
  5. Debug common issues in text classification pipelines

Hands-on coding with the 20 Newsgroups dataset

Workshop Overview

  1. Part 1: Loading and exploring real data
  2. Part 2: Feature engineering for text (BoW, TF-IDF)
  3. Part 3: Building classifiers (Naive Bayes, Logistic Regression, Neural Networks)
  4. Part 4: Model comparison and analysis
  5. Part 5: Error analysis and improvements
  6. Part 6: Real-world considerations (class imbalance)

xhour_classification_demo.ipynb

Part 1: The 20 Newsgroups Dataset

  • Posts from 20 different newsgroups
  • ~18,800 documents in scikit-learn's version (about 20,000 in the original collection)
  • Good for learning classification fundamentals
  • sci.space — Science discussions about space
  • rec.sport.hockey — Sports discussions about hockey
  • talk.politics.misc — Political discussions
  • comp.graphics — Computer graphics

We use these four categories because their topics are relatively distinct, which makes the classification task easier to learn.

Loading the Data


from sklearn.datasets import fetch_20newsgroups

categories = [
    'sci.space',
    'rec.sport.hockey',
    'talk.politics.misc',
    'comp.graphics'
]

train_data = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers', 'quotes')  # Remove metadata
)

print(f"Loaded {len(train_data.data)} training documents")

Always explore your data before building models

  1. How many documents per category?
  2. What do the documents look like?
  3. What words/phrases might be good indicators?
  4. Are there categories that might be hard to distinguish?

Exploring the Data: Concrete Example


# Check class distribution
# (target_names is sorted alphabetically, so label 0 is comp.graphics)
print("Documents per category:")
for i, name in enumerate(train_data.target_names):
    count = (train_data.target == i).sum()
    print(f"  {name}: {count}")
# comp.graphics: 584, rec.sport.hockey: 600, sci.space: 593, talk.politics.misc: 465

# Look at a sample sci.space document (look the label up by name, not position)
space_label = train_data.target_names.index('sci.space')
idx = next(i for i, t in enumerate(train_data.target) if t == space_label)
print(train_data.data[idx][:500])
# e.g. "NASA announced today that the Mars rover has discovered evidence of water..."

Classes are roughly balanced (good!), but talk.politics.misc has fewer examples.

Convert text to numbers for machine learning

  1. Bag of Words (BoW): Count word frequencies
  2. TF-IDF: Weight word counts by inverse document frequency
  3. Dense embeddings (preview; covered in future lectures)

Bag of Words: count how many times each word appears


from sklearn.feature_extraction.text import CountVectorizer

bow_vectorizer = CountVectorizer(
    max_features=5000,   # Keep only the 5000 most frequent words
    min_df=2,            # Word must appear in at least 2 docs
    max_df=0.8,          # Drop words that appear in more than 80% of docs
    stop_words='english' # Remove common English stop words
)

X_train_bow = bow_vectorizer.fit_transform(train_data.data)

Result: Sparse matrix of word counts

Bag of Words: Limitations

Captures:

  • Word presence/frequency
  • Vocabulary overlap between documents

Loses:

  • Word order ("not good" vs "good not")
  • Semantics ("great" vs "excellent")
  • Context

Common words dominate but are often uninformative!

BoW vectors are sparse: mostly zeros


from sklearn.feature_extraction.text import CountVectorizer

docs = ["NASA launches rocket to Mars", "Hockey game ends in overtime",
        "NASA discovers water on Mars"]

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(docs)

print("Vocabulary:", vectorizer.vocabulary_)
# {'nasa': 7, 'launches': 5, 'rocket': 10, 'to': 11, 'mars': 6, ...}
# (13 words in total; column indices follow alphabetical order)

print("\nDocument 1:", X[0].toarray())
# [[0 0 0 0 0 1 1 1 0 0 1 1 0]] <- counts for each word

Most entries are 0 (sparse!). Documents 1 and 3 share "mars" and "nasa".

Method 2: TF-IDF

$$\text{TF-IDF}(t, d) = \text{TF}(t, d) \times \text{IDF}(t)$$

where:

  • $\text{TF}(t, d)$ = frequency of term $t$ in document $d$
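  • $\text{IDF}(t) = \log\frac{N}{\text{df}(t)}$ = inverse document frequency, where $N$ is the number of documents and $\text{df}(t)$ is the number of documents containing $t$ (scikit-learn's default is the smoothed variant $\log\frac{1 + N}{1 + \text{df}(t)} + 1$)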

Downweight common words, upweight rare informative words!

TF-IDF in Practice


from sklearn.feature_extraction.text import TfidfVectorizer

tfidf_vectorizer = TfidfVectorizer(
    max_features=5000,
    min_df=2,
    max_df=0.8,
    stop_words='english',
    use_idf=True,
    sublinear_tf=True  # Use 1 + log(tf) instead of raw term frequency
)

X_train_tfidf = tfidf_vectorizer.fit_transform(train_data.data)

Result: Sparse matrix of TF-IDF scores

BoW vs TF-IDF Comparison

Word BoW Count TF-IDF Score
"the" 15 0.02 (low — common everywhere)
"nasa" 3 0.45 (high — rare, informative)
"space" 5 0.38 (moderate — distinctive)

TF-IDF identifies the truly distinctive terms! (Scores are illustrative; with stop_words='english', "the" would be removed before weighting.)

Three classifier approaches to compare

  1. Naive Bayes: Fast, probabilistic, good baseline
  2. Logistic Regression: Linear, interpretable, often best
  3. Neural Network: Flexible, can learn complex patterns

Classifier 1: Naive Bayes

$$P(y \mid x) \propto P(y) \prod_{i=1}^{n} P(x_i \mid y)$$

  • Why "naive"? Assumes features are conditionally independent given the class (they're not!)
  • Why does it work? Despite the wrong assumption, it often performs well for text.

from sklearn.naive_bayes import MultinomialNB

nb = MultinomialNB()  # designed for counts, but also works with fractional TF-IDF values
nb.fit(X_train_tfidf, train_data.target)

Classifier 2: Logistic Regression

$$P(y = k \mid x) = \frac{e^{w_k^T x}}{\sum_{j} e^{w_j^T x}}$$

  • Interpretable weights (which words matter?)
  • Often outperforms Naive Bayes
  • Fast training and prediction

from sklearn.linear_model import LogisticRegression

lr = LogisticRegression(max_iter=1000, C=1.0)  # C = inverse regularization strength
lr.fit(X_train_tfidf, train_data.target)

Logistic regression weights reveal which words matter

Category Top Positive Features
sci.space nasa, orbit, shuttle, moon, launch
rec.sport.hockey hockey, nhl, team, game, play
talk.politics.misc government, president, tax, policy
comp.graphics image, graphics, 3d, rendering

The model learns what we'd expect! Interpretability matters.

Classifier 3: Simple Neural Network

Input (TF-IDF) Hidden Layer 1 (256) Hidden Layer 2 (128) Output (4 classes)
Feedforward architecture

import torch.nn as nn

class TextClassifier(nn.Module):
    """Feedforward classifier: input -> hidden_dim -> hidden_dim // 2 -> output_dim."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: float tensor of shape (batch, input_dim), e.g. dense TF-IDF rows.
        # Returns raw logits; nn.CrossEntropyLoss applies the softmax itself.
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)

Linear models are competitive with BoW features

Model Accuracy
Naive Bayes (BoW) ~85%
Naive Bayes (TF-IDF) ~87%
Logistic Regression ~90%
Neural Network ~89%

Neural networks shine with richer representations (embeddings), not BoW.

Accuracy alone is not enough

  • Precision: Of predicted positives, how many are truly positive?
  • Recall: Of actual positives, how many did we catch?
  • F1-Score: Harmonic mean of precision and recall: $F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}$

Accuracy can be misleading when datasets are imbalanced or when different errors have different costs.

Confusion Matrix

Predicted A Predicted B Predicted C Predicted D
Actual A 85 2 3 0
Actual B 1 92 2 5
Actual C 4 3 88 5
Actual D 0 8 2 90

Diagonal = correct predictions. Off-diagonal = errors.

Error analysis: critical but often skipped

  1. Find misclassified examples
  2. Look for patterns
  3. Understand why the model failed
  4. Use insights to improve

Typical causes: mixed topics, short documents, unusual vocabulary.

Error Analysis: Concrete Example


import numpy as np

# Vectorize the test set with the training vocabulary (transform only, no refit!)
test_data = fetch_20newsgroups(subset='test', categories=categories,
                               remove=('headers', 'footers', 'quotes'))
X_test_tfidf = tfidf_vectorizer.transform(test_data.data)

# Find misclassified examples
y_pred = lr.predict(X_test_tfidf)
errors = np.where(y_pred != test_data.target)[0]
idx = errors[0]

print(f"True: {test_data.target_names[test_data.target[idx]]}")
print(f"Pred: {test_data.target_names[y_pred[idx]]}")
print(test_data.data[idx][:300])

True: sci.space Predicted: comp.graphics

"I'm working on a 3D visualization of the solar system for my graphics project..."

The document mentions both graphics AND space, so the model's confusion is understandable!

Common Error Patterns

  1. Ambiguous content: Document mentions multiple topics
  2. Limited context: Very short documents
  3. Domain shift: Test data differs from training
  4. Rare vocabulary: Important words not in training
How to address them:

  • Better preprocessing
  • More features (bigrams, trigrams)
  • Domain-specific fine-tuning

Class imbalance biases models toward predicting the majority class

  1. Class weights: Penalize minority errors more
  2. Oversampling: Duplicate minority examples
  3. Undersampling: Remove majority examples
  4. SMOTE: Generate synthetic minority examples

lr_balanced = LogisticRegression(
    max_iter=1000,           # Same iteration budget as the unweighted model
    class_weight='balanced'  # Weight classes inversely to their frequency
)

Discussion Questions

  1. BoW vs TF-IDF: When would you prefer one over the other?

  2. Linear vs Neural: Why didn't the neural network significantly outperform logistic regression?

  3. Feature Engineering: How important was feature engineering compared to model choice?

  4. Scalability: Which approach scales best to millions of documents?

  5. Interpretability: Which models are most interpretable? Why does it matter?

Key Takeaways

  1. Good features matter more than complex models (for many tasks)
  2. TF-IDF usually beats raw BoW for text classification
  3. Linear models are competitive with neural networks on bag-of-words
  4. Always examine errors to understand model behavior
  5. Consider class imbalance and adjust accordingly

Connection to Course Themes

Data Cleaning (Lecture 5) Tokenization (Lecture 6) Feature Extraction (Today) Classification (Today)

POS Tagging & Sentiment Analysis — How do these building blocks combine for real NLP applications?

Hands-On Exercise

xhour_classification_demo.ipynb

  1. Load the 20 Newsgroups dataset
  2. Experiment with different vectorizers
  3. Train multiple classifiers
  4. Analyze errors and improve
  5. Try your own text examples!

Build intuition for text classification

Additional Resources

  • 20 Newsgroups: Classic benchmark
  • IMDb Reviews: Sentiment classification
  • AG News: News categorization

Questions? Want to chat more?

📧 Email me
💬 Join our Discord
💁 Come to office hours

Lecture 8 — POS Tagging & Sentiment Analysis