Rust - Threads

Overview

Estimated time: 40–50 minutes

Master concurrent programming in Rust using threads. Learn how to spawn threads, share data safely, handle thread communication, and understand Rust's approach to thread safety without data races.

Learning Objectives

Prerequisites

Creating Threads

Basic Thread Spawning

Create threads using std::thread::spawn:

use std::thread;
use std::time::Duration;

/// Runs a spawned thread side by side with the main thread, then waits for it.
///
/// The spawned thread prints the numbers 1 through 9 and the main thread
/// prints 1 through 4, each pausing one millisecond between lines.
fn main() {
    // Background worker: the closure runs on a newly spawned thread.
    let worker = thread::spawn(|| {
        (1..10).for_each(|step| {
            println!("Thread: {}", step);
            thread::sleep(Duration::from_millis(1));
        });
    });

    // The main thread keeps going on its own while the worker runs.
    let pause = Duration::from_millis(1);
    let mut tick = 1;
    while tick < 5 {
        println!("Main: {}", tick);
        thread::sleep(pause);
        tick += 1;
    }

    // Block until the worker is done; `unwrap` panics here if the worker panicked.
    worker.join().unwrap();
    println!("Both threads finished");
}

Thread with Return Values

Threads can return values that are collected when joined:

use std::thread;

/// Demonstrates returning values from threads through `JoinHandle::join`.
///
/// First a single thread sums 1 through 100; then five threads each sum a
/// 20-element slice of `0..100` and the partial results are combined.
fn main() {
    // The closure's final expression becomes the thread's return value.
    let handle = thread::spawn(|| {
        let mut sum = 0;
        for i in 1..=100 {
            sum += i;
        }
        sum // Return value
    });

    // Do other work in main thread
    println!("Calculating sum in background...");

    // `join` yields `Ok(value)` on success or `Err(payload)` if the thread panicked.
    match handle.join() {
        Ok(result) => println!("Sum: {}", result),
        Err(e) => println!("Thread panicked: {:?}", e),
    }

    // Multiple threads with results: thread `i` sums the range `i*20 .. i*20 + 20`.
    let handles: Vec<_> = (0..5)
        .map(|i| {
            thread::spawn(move || {
                let start = i * 20;
                let end = start + 20;
                // The turbofish names the sum's result type explicitly.
                (start..end).sum::<i32>()
            })
        })
        .collect();

    // Join in spawn order so `results[i]` belongs to thread `i`.
    let results: Vec<i32> = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .collect();

    println!("Partial sums: {:?}", results);
    println!("Total: {}", results.iter().sum::<i32>());
}

Move Closures and Ownership

Transferring Ownership to Threads

Use move closures to transfer ownership of variables to threads:

use std::thread;

/// Demonstrates `move` closures handing ownership of local values to threads.
fn main() {
    let data = vec![1, 2, 3, 4, 5];
    let name = String::from("Worker");

    // `move` transfers ownership of `data` and `name` into the thread.
    let handle = thread::spawn(move || {
        println!("Thread '{}' processing: {:?}", name, data);
        // The turbofish names the sum's result type explicitly.
        data.iter().sum::<i32>()
    });

    // data and name are no longer available in main thread
    // println!("{:?}", data); // This would cause a compile error

    let result = handle.join().unwrap();
    println!("Result: {}", result);

    // Example with multiple pieces of data: all three values are moved
    // (or, for the integer, copied) into the closure.
    let numbers = vec![10, 20, 30];
    let multiplier = 2;
    let prefix = "Result:".to_string();

    let handle = thread::spawn(move || {
        let sum: i32 = numbers.iter().sum();
        format!("{} {}", prefix, sum * multiplier)
    });

    println!("{}", handle.join().unwrap());
}

Cloning for Thread Safety

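Wrap data in an Arc and clone the Arc (a cheap reference-count increment, not a copy of the data) when multiple threads need read access to the same value: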

use std::thread;
use std::sync::Arc;

/// Shares one read-only vector between several threads through `Arc`.
///
/// Each thread sums its own contiguous chunk of the vector; the last thread
/// also takes whatever remainder is left when the length does not divide evenly.
fn main() {
    // Single source of truth for the number of workers and chunks.
    const NUM_THREADS: usize = 3;

    let shared_data = Arc::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    let mut handles = Vec::with_capacity(NUM_THREADS);

    // Spawn multiple threads that all read the same data
    for i in 0..NUM_THREADS {
        // Cloning the `Arc` bumps a reference count; the vector itself is not copied.
        let data_clone = Arc::clone(&shared_data);
        let handle = thread::spawn(move || {
            let chunk_size = data_clone.len() / NUM_THREADS;
            let start = i * chunk_size;
            // The final chunk runs to the end so no trailing elements are dropped.
            let end = if i == NUM_THREADS - 1 {
                data_clone.len()
            } else {
                start + chunk_size
            };

            let sum: i32 = data_clone[start..end].iter().sum();
            println!("Thread {} ({}..{}): {}", i, start, end, sum);
            sum
        });
        handles.push(handle);
    }

    // Collect results from all threads, in spawn order.
    let results: Vec<i32> = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .collect();

    println!("Thread results: {:?}", results);
    println!("Total: {}", results.iter().sum::<i32>());
}

Thread Communication

Shared State with Mutex

Use Mutex for safe shared mutable state:

use std::sync::{Arc, Mutex};
use std::thread;

/// Demonstrates shared mutable state guarded by a `Mutex` behind an `Arc`.
///
/// Part one has five threads each bump a shared counter ten times; part two
/// has three threads each append five strings to a shared list.
fn main() {
    // Part one: a counter shared by every incrementing thread.
    let counter = Arc::new(Mutex::new(0));

    let incrementers: Vec<_> = (0..5)
        .map(|thread_no| {
            let shared = Arc::clone(&counter);
            thread::spawn(move || {
                for _ in 0..10 {
                    // The guard is held for the increment and the print, then
                    // released when it goes out of scope at the end of the iteration.
                    let mut value = shared.lock().unwrap();
                    *value += 1;
                    println!("Thread {} incremented counter to {}", thread_no, *value);
                }
            })
        })
        .collect();

    incrementers
        .into_iter()
        .for_each(|joiner| joiner.join().unwrap());

    println!("Final counter value: {}", *counter.lock().unwrap());

    // Part two: a vector shared by every appending thread.
    let shared_list = Arc::new(Mutex::new(Vec::new()));

    let appenders: Vec<_> = (0..3)
        .map(|thread_no| {
            let shared = Arc::clone(&shared_list);
            thread::spawn(move || {
                for item_no in 0..5 {
                    let mut entries = shared.lock().unwrap();
                    entries.push(format!("Thread {}: Item {}", thread_no, item_no));
                }
            })
        })
        .collect();

    appenders
        .into_iter()
        .for_each(|joiner| joiner.join().unwrap());

    // All writers have finished, so this lock is uncontended.
    let final_list = shared_list.lock().unwrap();
    println!("Final list ({} items):", final_list.len());
    final_list.iter().for_each(|entry| println!("  {}", entry));
}

Thread Patterns

Worker Pool Pattern

Create a pool of worker threads for processing tasks:

use std::sync::{Arc, Mutex, mpsc};
use std::thread;
use std::time::Duration;

/// A unit of work handed to a worker thread.
struct Task {
    /// Number used to identify the task in log output.
    id: usize,
    /// Payload printed when the task is processed.
    data: String,
}

impl Task {
    /// Logs the task, sleeps 100 ms to simulate work, then logs completion.
    fn process(&self) {
        let Task { id, data } = self;
        println!("Processing task {}: {}", id, data);

        // Stand-in for real work.
        let simulated_work = Duration::from_millis(100);
        thread::sleep(simulated_work);

        println!("Task {} completed", id);
    }
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    let receiver = Arc::new(Mutex::new(receiver));
    let mut workers = vec![];
    
    // Create worker threads
    for worker_id in 0..3 {
        let receiver_clone = Arc::clone(&receiver);
        let worker = thread::spawn(move || {
            loop {
                let task = {
                    let receiver = receiver_clone.lock().unwrap();
                    receiver.recv()
                };
                
                match task {
                    Ok(task) => {
                        println!("Worker {} got task {}", worker_id, task.id);
                        task.process();
                    }
                    Err(_) => {
                        println!("Worker {} shutting down", worker_id);
                        break;
                    }
                }
            }
        });
        workers.push(worker);
    }
    
    // Send tasks to workers
    for i in 0..10 {
        let task = Task {
            id: i,
            data: format!("Task data {}", i),
        };
        sender.send(task).unwrap();
    }
    
    // Close the channel to signal workers to shut down
    drop(sender);
    
    // Wait for all workers to finish
    for worker in workers {
        worker.join().unwrap();
    }
    
    println!("All tasks completed");
}

Producer-Consumer Pattern

Coordinate data production and consumption between threads:

use std::collections::VecDeque;
use std::sync::{Arc, Mutex, Condvar};
use std::thread;
use std::time::Duration;

/// A bounded FIFO queue whose `push` and `pop` block instead of failing.
///
/// The queue is protected by a `Mutex`; two condition variables let threads
/// sleep until there is room to push or an item to pop.
struct Buffer<T> {
    /// Items waiting to be consumed, oldest at the front.
    queue: Mutex<VecDeque<T>>,
    /// Signalled after every push so a waiting consumer can wake up.
    not_empty: Condvar,
    /// Signalled after every pop so a waiting producer can wake up.
    not_full: Condvar,
    /// Maximum number of items held at once.
    capacity: usize,
}

impl<T> Buffer<T> {
    /// Creates an empty buffer that holds at most `capacity` items.
    ///
    /// Note: with a `capacity` of 0 the "full" condition in `push` is always
    /// true, so every `push` would block forever.
    fn new(capacity: usize) -> Self {
        Buffer {
            queue: Mutex::new(VecDeque::new()),
            not_empty: Condvar::new(),
            not_full: Condvar::new(),
            capacity,
        }
    }

    /// Appends `item`, blocking while the buffer is at capacity.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn push(&self, item: T) {
        let mut queue = self.queue.lock().unwrap();

        // Re-check in a loop: `wait` releases the lock while sleeping and may
        // return even though the buffer is still full.
        while queue.len() >= self.capacity {
            queue = self.not_full.wait(queue).unwrap();
        }

        queue.push_back(item);
        self.not_empty.notify_one();
    }

    /// Removes and returns the oldest item, blocking while the buffer is empty.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn pop(&self) -> T {
        let mut queue = self.queue.lock().unwrap();

        // Re-check in a loop for the same reason as in `push`.
        while queue.is_empty() {
            queue = self.not_empty.wait(queue).unwrap();
        }

        // The loop above guarantees the queue is non-empty here.
        let item = queue.pop_front().unwrap();
        self.not_full.notify_one();
        item
    }
}

/// Wires one producer and two consumers to a shared five-slot `Buffer`.
///
/// The producer pushes 20 items; each of the two consumers pops 10, so every
/// produced item is consumed and all three threads terminate.
fn main() {
    let buffer = Arc::new(Buffer::new(5));

    // Producer: emits items 0 through 19, pausing 50 ms after each push.
    let producer = {
        let outbox = Arc::clone(&buffer);
        thread::spawn(move || {
            (0..20).for_each(|item| {
                println!("Producing item {}", item);
                outbox.push(item);
                thread::sleep(Duration::from_millis(50));
            });
            println!("Producer finished");
        })
    };

    // Consumers: two threads, each taking exactly ten items at 100 ms apiece.
    let consumers: Vec<_> = (0..2)
        .map(|consumer_id| {
            let inbox = Arc::clone(&buffer);
            thread::spawn(move || {
                let mut remaining = 10;
                while remaining > 0 {
                    let item = inbox.pop();
                    println!("Consumer {} consumed item {}", consumer_id, item);
                    thread::sleep(Duration::from_millis(100));
                    remaining -= 1;
                }
                println!("Consumer {} finished", consumer_id);
            })
        })
        .collect();

    // Wait for the producer first, then for each consumer.
    producer.join().unwrap();
    consumers
        .into_iter()
        .for_each(|consumer| consumer.join().unwrap());
}

Thread Local Storage

Thread-Local Variables

Store data that's unique to each thread:

use std::cell::RefCell;
use std::thread;

thread_local! {
    // Every thread gets its own independent counter starting at 0.
    // `RefCell` supplies the interior mutability needed to update it through
    // the shared reference that `with` hands out.
    static COUNTER: RefCell<u32> = RefCell::new(0);
}

/// Adds one to the calling thread's `COUNTER` and prints the new value
/// together with the current thread's id.
fn increment_counter() {
    COUNTER.with(|cell| {
        // The mutable borrow ends with this statement, before the read below.
        *cell.borrow_mut() += 1;
        println!("Thread {:?}: Counter = {}", thread::current().id(), *cell.borrow());
    });
}

/// Returns the calling thread's current `COUNTER` value.
fn get_counter() -> u32 {
    COUNTER.with(|cell| {
        let current = cell.borrow();
        *current
    })
}

/// Shows that each thread sees its own copy of the thread-local counter.
///
/// Three spawned threads each increment five times, while the main thread
/// increments twice; every thread then reports its own final count.
fn main() {
    let workers: Vec<_> = (0..3)
        .map(|thread_no| {
            thread::spawn(move || {
                println!("Thread {} starting", thread_no);

                // These increments touch only this thread's counter.
                (0..5).for_each(|_| increment_counter());

                let final_count = get_counter();
                println!("Thread {} final count: {}", thread_no, final_count);
            })
        })
        .collect();

    // The main thread has a counter of its own, independent of the workers.
    for _ in 0..2 {
        increment_counter();
    }
    println!("Main thread final count: {}", get_counter());

    workers
        .into_iter()
        .for_each(|worker| worker.join().unwrap());
}

Error Handling in Threads

Handling Thread Panics

Deal with panics and errors in threaded code:

use std::thread;
use std::panic;

/// Extracts a human-readable message from a thread's panic payload.
///
/// `panic!` with a plain string literal carries a `&'static str`; `panic!`
/// with format arguments carries a `String`. Any other payload type yields a
/// fixed placeholder.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
    if let Some(s) = payload.downcast_ref::<&str>() {
        *s
    } else if let Some(s) = payload.downcast_ref::<String>() {
        s.as_str()
    } else {
        "<non-string panic payload>"
    }
}

/// Demonstrates three ways failures surface from threads: a panic observed
/// through `join`, panics collected across several threads, and an ordinary
/// `Result` returned from the thread's closure.
fn main() {
    // Thread that panics
    let handle = thread::spawn(|| {
        panic!("Something went wrong!");
    });

    // `join` returns `Err(payload)` when the thread panicked.
    match handle.join() {
        Ok(_) => println!("Thread completed successfully"),
        Err(e) => {
            println!("Thread panicked!");
            // `&*e` passes the payload itself, not the `Box` wrapping it.
            println!("Panic message: {}", panic_message(&*e));
        }
    }

    // Catching panics with custom handling
    let handles: Vec<_> = (0..5)
        .map(|i| {
            thread::spawn(move || {
                if i == 2 {
                    panic!("Thread {} panicked!", i);
                }
                format!("Thread {} completed", i)
            })
        })
        .collect();

    let mut results = vec![];
    let mut panics = vec![];

    for (i, handle) in handles.into_iter().enumerate() {
        match handle.join() {
            Ok(result) => results.push(result),
            Err(e) => {
                panics.push(i);
                // Debug-printing the boxed payload would only show `Any { .. }`,
                // so extract the actual message instead.
                println!("Thread {} panicked: {}", i, panic_message(&*e));
            }
        }
    }

    println!("Successful results: {:?}", results);
    println!("Threads that panicked: {:?}", panics);

    // Using Result for recoverable errors: the thread finishes normally and
    // reports failure through its return value rather than by panicking.
    let handle = thread::spawn(|| -> Result<i32, String> {
        let x: i32 = 10;
        let y: i32 = 0;

        // `checked_div` returns `None` for a zero divisor instead of panicking.
        x.checked_div(y)
            .ok_or_else(|| "Division by zero".to_string())
    });

    // Outer `Result` = did the thread panic; inner `Result` = the thread's own outcome.
    match handle.join() {
        Ok(Ok(result)) => println!("Division result: {}", result),
        Ok(Err(error)) => println!("Thread returned error: {}", error),
        Err(_) => println!("Thread panicked"),
    }
}

Performance Considerations

Thread Overhead and Best Practices

Understanding thread costs and optimization strategies:

use std::thread;
use std::time::Instant;

/// Returns the sum of every integer from 1 through `n` inclusive
/// (0 when `n` is 0), computed by explicit accumulation.
fn cpu_intensive_work(n: u64) -> u64 {
    let mut total: u64 = 0;
    for term in 1..=n {
        total += term;
    }
    total
}

fn main() {
    const WORK_SIZE: u64 = 1_000_000;
    const NUM_THREADS: usize = 4;
    
    // Sequential execution
    let start = Instant::now();
    let sequential_result = cpu_intensive_work(WORK_SIZE);
    let sequential_time = start.elapsed();
    
    println!("Sequential result: {} in {:?}", sequential_result, sequential_time);
    
    // Parallel execution
    let start = Instant::now();
    let chunk_size = WORK_SIZE / NUM_THREADS as u64;
    
    let handles: Vec<_> = (0..NUM_THREADS).map(|i| {
        thread::spawn(move || {
            let start = i as u64 * chunk_size + 1;
            let end = if i == NUM_THREADS - 1 { 
                WORK_SIZE 
            } else { 
                (i + 1) as u64 * chunk_size 
            };
            
            (start..=end).sum::()
        })
    }).collect();
    
    let parallel_result: u64 = handles.into_iter()
        .map(|h| h.join().unwrap())
        .sum();
    
    let parallel_time = start.elapsed();
    
    println!("Parallel result: {} in {:?}", parallel_result, parallel_time);
    println!("Speedup: {:.2}x", sequential_time.as_secs_f64() / parallel_time.as_secs_f64());
    
    // Thread creation overhead
    let start = Instant::now();
    let handles: Vec<_> = (0..1000).map(|i| {
        thread::spawn(move || i * 2)
    }).collect();
    
    let _results: Vec = handles.into_iter()
        .map(|h| h.join().unwrap())
        .collect();
    
    println!("1000 thread creations took: {:?}", start.elapsed());
    
    // Guidelines for thread usage
    println!("\nThread Usage Guidelines:");
    println!("- Use threads for CPU-intensive, parallelizable work");
    println!("- Thread creation has overhead - consider thread pools for many small tasks");
    println!("- Number of threads should typically match CPU cores for CPU-bound work");
    println!("- For I/O-bound work, you can use more threads than CPU cores");
}

Best Practices

Threading Guidelines

Common Pitfalls

Mistakes to Avoid

Checks for Understanding

  1. What happens if you don't call join() on a thread handle?
  2. When do you need to use move closures with threads?
  3. What's the difference between Arc<T> and Arc<Mutex<T>>?
  4. How do you determine the optimal number of threads for CPU-bound work?
Answers
  1. The thread is detached and keeps running on its own; if the main thread returns first, the process exits and the spawned thread is terminated before it finishes
  2. When the thread needs to take ownership of variables from the enclosing scope
  3. Arc<T> allows shared read-only access; Arc<Mutex<T>> allows shared mutable access
  4. Generally, use the number of CPU cores available, which you can get with thread::available_parallelism()

← Previous | Next →