Sudoku backtracking in Rust

This is the first in a series of posts where I will explore the Rust programming language by using it to implement several different Sudoku solving algorithms. I think that this Japanese puzzle game is ideal for such an exercise, since it lends itself beautifully to a variety of distinct algorithmic approaches. My plan is to explore a new algorithm and a different aspect of the language in each of the three planned posts, starting with the more fundamental ones and later moving towards more advanced concepts.

First things first

Before we start with solving Sudokus, we'll need to define some core ideas, and also create a correctness checker for the puzzle, so that in the end we can be sure whether our program has actually found a solution.

First, we should define what a Sudoku is. This is the part where we will briefly touch upon Rust's rich type system. In our case we will represent the board as an array with $$9 \times 9 = 81$$ elements, each of which is either a missing value or a digit. We make use of the built-in algebraic data type Option<u8>, which is always exactly one of None or Some(x) (where x is an unsigned 8-bit integer). The language makes sure that each case is always dealt with, so there is no way to forget about a missing value (compare this with the NullPointerExceptions plaguing Java, or the garbage you get in C if you mistakenly treat special return values as correct function outputs!). In principle we could even define a Digit enumeration so that the compiler would only allow values from 1 to 9, but I think this would be overkill: it would force us to constantly convert between numbers and Digit and make the code less readable. Finally, we wrap everything into a new type Sudoku, defined as a tuple struct with a single field (this is the newtype pattern, and it allows us to implement our own methods for the type).

struct Sudoku([Option<u8>; 9 * 9]);
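To make the "each case is always dealt with" claim a bit more concrete, here is a minimal sketch. The function `describe` is a hypothetical helper that is not part of the solver; it only shows that a `match` on an Option<u8> has to cover both variants.

```rust
// Hypothetical helper, not used by the solver: it only illustrates
// that a `match` on Option<u8> must handle both variants.
fn describe(cell: Option<u8>) -> String {
    match cell {
        // Deleting either arm makes the compiler reject the code
        // with a "non-exhaustive patterns" error.
        Some(digit) => format!("digit {}", digit),
        None => String::from("empty"),
    }
}
```

Calling `describe(Some(7))` gives "digit 7" and `describe(None)` gives "empty".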

Rust is not an object-oriented language (at least, according to some definitions of object-oriented – e.g. it does not support inheritance). However, we can still define methods – functions associated with objects of a given type. We do so not in a struct body, but in independent implementation blocks.

impl Sudoku { 
}

Let's first define a helper method to make specifying Sudokus by hand as easy as writing a string literal. This will be a static method (an associated function, in Rust terminology), i.e. it will not take any reference to self, and it will return an instance of our struct.

    fn new(desc: &str) -> Self {
        // format: "123_5_789\n" x 9
        // `mut` keyword means that this variable can be mutated
        // Rust variables are immutable by default!
        let mut board : [Option<u8>; 9 * 9] = [None; 9 * 9];

        // In Rust all for loops work via iterators!
        for (line, row_index) in desc.split("\n").zip(0..9) {
            for (char, col_index) in line.chars().zip(0..9) {

                // Exhaustive pattern matching
                board[col_index + row_index * 9] = match char.to_digit(10) {
                    None => None,
                    Some(x) => Some(x as u8)
                }
            }
        }
        // The last expression in a function is its return value
        Sudoku(board)
    }
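The per-character step inside `new` can be looked at in isolation. The sketch below uses two hypothetical helpers, `parse_cell` and `parse_cell_strict`, that are not in the post. The first mirrors what the `match` above does. The second shows one possible way to also reject '0', which `to_digit(10)` accepts even though it is not a valid Sudoku digit.

```rust
// Hypothetical helper mirroring the `match` inside `new`.
fn parse_cell(c: char) -> Option<u8> {
    // `to_digit(10)` returns None for anything that is not a decimal
    // digit, so '_' (or any other placeholder) becomes an empty cell.
    c.to_digit(10).map(|d| d as u8)
}

// Hypothetical stricter variant: additionally treats '0' as empty,
// because only 1..=9 are meaningful on a Sudoku board.
fn parse_cell_strict(c: char) -> Option<u8> {
    c.to_digit(10).filter(|d| *d != 0).map(|d| d as u8)
}
```

The post's version behaves like `parse_cell`, so it relies on the input never containing a '0'.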

Another useful method – printing the board! We'll implement it in a separate implementation block for the Display trait. This is something similar to interfaces in other languages. Rust will know we agreed to the Display "contract" and it will use our method when we try to print an object of our type (using something like e.g. println!("{}", sudoku)).

impl std::fmt::Display for Sudoku {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut res: String = String::new();
        for y in 0..9 {
            res.extend((0..9).map(|x| match self.get(x, y) {
                Some(digit) => digit.to_string(),
                None => String::from("_")
            }));
            res.push_str("\n");
        }
        write!(f, "{}", res)
    }
}
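As a small illustration of the Display "contract", here is a sketch with a hypothetical one-row type `Row` (not part of the solver). It writes straight into the formatter instead of building an intermediate String, which is an alternative to the approach above rather than a correction of it.

```rust
use std::fmt;

// Hypothetical one-row type, used only to show how Display plugs into
// `format!`, `println!` and `to_string`.
struct Row([Option<u8>; 3]);

impl fmt::Display for Row {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for cell in self.0.iter() {
            match cell {
                // `?` forwards a formatting error to the caller
                Some(digit) => write!(f, "{}", digit)?,
                None => write!(f, "_")?,
            }
        }
        Ok(())
    }
}

fn render_example() -> String {
    // `to_string` comes for free with any Display implementation.
    Row([Some(1), None, Some(3)]).to_string()
}
```

Here `render_example()` returns "1_3".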

Some convenience methods: (you could argue they're unnecessary but the unfortunate .0 – accessing the first element of the tuple – is just so ugly!) Here we will for the first time use an immutable reference to our struct (denoted by &), and a mutable, or unique reference (&mut). Rust's borrow checker enforces uniqueness of mutable references to provide memory-safety and race-safety without a garbage collector.

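    fn get(&self, x: usize, y: usize) -> Option<u8> {
        self.0[x + 9 * y]
    }

    fn set(&mut self, x: usize, y: usize, val: u8) {
        self.0[x + 9 * y] = Some(val);
    }

    // Empties a cell again; the backtracker uses this to undo a guess
    fn clear(&mut self, x: usize, y: usize) {
        self.0[x + 9 * y] = None;
    }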
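The difference between the kinds of receivers is perhaps easier to see on a toy type. `Counter` below is hypothetical and has nothing to do with the solver; it only puts an associated function, a `&self` method and a `&mut self` method side by side.

```rust
// Hypothetical toy newtype, unrelated to the solver.
struct Counter(u32);

impl Counter {
    // Associated ("static") function: no `self` at all
    fn new() -> Self {
        Counter(0)
    }

    // Shared borrow: may read, but not modify
    fn value(&self) -> u32 {
        self.0
    }

    // Unique borrow: may modify, and the caller needs a `mut` binding
    fn bump(&mut self) {
        self.0 += 1;
    }
}
```

A caller would write `let mut c = Counter::new();` and could then call `c.bump()`; without the `mut` the call to `bump` does not compile.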

Now, to check the correctness of the board we can turn each row, column and block into a set of digits. If any of these sets has a size other than nine, then either a digit is missing or the same digit appears more than once. At that point it's OK to return early and report a mistake. Here we assume that no digit other than 1..9 will make it into our code.

    fn is_solved(&self) -> bool {
        for row_index in 0..9 {
            // Lambda function used as an argument to `map` method of an iterator
            // (And this language is just as performant as C. Amazing!)
            let different_digits : BTreeSet<_> = (0..9).map(|col_index| self.get(col_index, row_index))
                .filter(Option::is_some).collect();
            if different_digits.len() < 9 {
                return false;
            }
        }

        for col_index in 0..9 {
            let different_digits : BTreeSet<_> = (0..9).map(|row_index| self.get(col_index, row_index))
                .filter(Option::is_some).collect();
            if different_digits.len() < 9 {
                return false;
            }
        }

        for block_index in 0..9 {
            let bx = block_index / 3;
            let by = block_index % 3;
            let different_digits : BTreeSet<_> = (0..9)
                .map(|i| {
                    let x = i / 3;
                    let y = i % 3;
                    self.get(x + 3 * bx, y + 3 * by)
                })
                .filter(Option::is_some).collect();
            if different_digits.len() < 9 {
                return false;
            }
        }

        true
    }
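The three loops share the same "collect into a set and count" step. As a sketch, that step could live in a helper of its own. The name `has_nine_distinct` is hypothetical, and the helper works on a plain array of nine cells, so it does not depend on the Sudoku type.

```rust
use std::collections::BTreeSet;

// Hypothetical helper: true when the nine cells hold nine different digits.
fn has_nine_distinct(cells: &[Option<u8>; 9]) -> bool {
    // `flatten` drops the None entries and unwraps the Some ones,
    // and the set then removes duplicates.
    let digits: BTreeSet<u8> = cells.iter().copied().flatten().collect();
    digits.len() == 9
}
```

Like `is_solved`, this assumes that only the digits 1..9 ever appear; nine different values that include a 0 would be accepted as well.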

To be sure that what we just wrote is actually correct, we need some basic tests! Fortunately, the Rust toolchain has first-class support for unit tests. We'll test our code on a few very simple cases...

#[test]
fn solved_sudoku_is_solved() {
    let sudoku = Sudoku::new(
"123456789
456789123
789123456
234567891
567891234
891234567
345678912
678912345
912345678");
    assert!(sudoku.is_solved());
}

#[test]
fn wrong_sudoku_is_not_solved() {
    // Repeated '9' in the first row
    let sudoku = Sudoku::new(
"923456789
456789123
789123456
234567891
567891234
891234567
345678912
678912345
912345678");
    assert!(!sudoku.is_solved());
}

#[test]
fn sudoku_with_missing_digits_is_not_solved() {
    let sudoku = Sudoku::new(
"_23456789
456789123
789123456
234567891
5678912_4
891234567
345678912
678912345
912345678");
    assert!(!sudoku.is_solved());
}

Now we have all the equipment needed to explore all the different ways you can approach the Sudoku problem!

The backtracker

The first idea that comes to mind is to just brute-force our way through, i.e. try all possible combinations of digits and see which combination is a valid solution to our puzzle. However, there are $$9^{9 \times 9} \approx 2 \times 10^{77}$$ possible combinations of digits in a 9 × 9 grid, which for our purposes might as well be infinite. On the other hand, we can from the start exclude all candidates that:

- do not form a valid sudoku, or
- are in conflict with the provided hints.
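For a sense of scale, the size of this search space can be estimated with a one-liner. This is only a back-of-the-envelope sketch, and `search_space_log10` is a hypothetical name.

```rust
// Order of magnitude of 9^81, computed via logarithms because the exact
// value does not fit in any built-in integer type (u128 tops out
// near 3.4e38).
fn search_space_log10() -> f64 {
    81.0 * 9.0_f64.log10()
}
```

The result is roughly 77.3, i.e. 9^81 is a 78-digit number.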

The idea of backtracking is a very simple one. It is just a brute-force algorithm that fills the free spots in the grid one by one, choosing digits in a deterministic order, and backtracks a step as soon as it makes the grid invalid. When at some point the algorithm has exhausted all possible digits for a given grid point, it goes back again.

Sudoku solved by backtracking

A brilliant illustration of what we are talking about here. Thanks to Simpsons contributor / CC BY-SA and Wikimedia Commons!

Surprisingly, this simple backtracking approach is efficient enough to solve a typical real-world "hard" sudoku within milliseconds on a modern laptop. So how do we go about implementing it in Rust?

This problem seems like it would lend itself nicely to recursion. So let's start with that. We have at most $$9 \times 9 = 81$$ fields on our board, so we should not run into a stack overflow (we will solve the board field-by-field starting in the top left and moving to the right or to the next row in each successive recursive call).

Let's begin with a helper function and an enum representing the state of the intermediate calculations, which will help us later.

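enum SudokuBruteSolveResult {
    Solved,
    NotYet
}

// Lets us write `Solved` and `NotYet` without the enum prefix
use SudokuBruteSolveResult::{Solved, NotYet};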

fn index_to_coords(index: usize) -> (usize, usize) {
    let x = index % 9;
    let y = index / 9;
    (x, y)
}
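A quick way to convince yourself that this mapping agrees with the `x + 9 * y` layout used by `get` and `set` is to write down the inverse. `coords_to_index` below is a hypothetical helper that the post itself does not need; `index_to_coords` is repeated in compact form so that the snippet stands on its own.

```rust
// Same mapping as in the post, written compactly.
fn index_to_coords(index: usize) -> (usize, usize) {
    (index % 9, index / 9)
}

// Hypothetical inverse, matching the `x + 9 * y` layout of `get` and `set`.
fn coords_to_index(x: usize, y: usize) -> usize {
    x + 9 * y
}
```

Index 0 is the top-left corner (0, 0), index 8 ends the first row at (8, 0), index 9 starts the second row at (0, 1), and index 80 is the bottom-right corner (8, 8).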

Now, let's define the function that changes our algorithm from a mindless brute-force into a sophisticated baque-traque; i.e. the function that will tell us which are the possible digits at the current position. We will need a reference to the sudoku board (here we won't modify it so an immutable one will suffice), as well as a way to keep track of which place on it is currently being considered. For this second goal we'll use a sequentially increasing index; this is why we defined index_to_coords above.

// add `use std::collections::BTreeSet;` at the top of the file
// to use std-lib sets (implemented as a B-tree)
// On editions before 2021, `from_iter` below also needs
// `use std::iter::FromIterator;`

fn potential_digits(sudoku: &Sudoku, index: usize) -> BTreeSet<u8> {
    let (x, y) = index_to_coords(index);

    // We'll start with a full set and then filter "manually"
    let mut res = BTreeSet::from_iter(1..=9);

    for i in 0..9 {
        // No same digits in the same row!
        if i != x {
            match sudoku.get(i, y) {
                None => (),
                Some(other) => { res.remove(&other); }
            }
        }
    }

    for i in 0..9 {
        // No same digits in the same column!
        if i != y {
            match sudoku.get(x, i) {
                None => (),
                Some(other) => { res.remove(&other); }
            }
        }
    }

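    // No same digits in the same block!
    // Cells sharing a row or a column with (x, y) were already handled by
    // the two loops above, so the condition below skips them and checks
    // only the four remaining cells of the block.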
    let bx = x / 3;
    let by = y / 3;
    let local_x = x % 3;
    let local_y = y % 3;
    for other_x in 0..3 {
        for other_y in 0..3 {
            if other_x != local_x && other_y != local_y {
                match sudoku.get(bx * 3 + other_x, by * 3 + other_y) {
                    None => (),
                    Some(other) => { res.remove(&other); }
                }
            }
        }
    }

    res
}
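The same filtering can also be done without a set. The sketch below is an alternative, not the post's method: it works on a plain 81-element array instead of the Sudoku type, tracks seen digits in a small boolean table, and uses the hypothetical name `candidates`. It assumes that every stored digit is at most 9.

```rust
// Alternative sketch: candidate digits for one cell of a flat board.
// Assumes every stored digit is <= 9 (a larger value would index out
// of bounds and panic).
fn candidates(board: &[Option<u8>; 81], index: usize) -> Vec<u8> {
    let (x, y) = (index % 9, index / 9);
    // used[d] becomes true once digit d is seen in a peer cell.
    let mut used = [false; 10];

    for i in 0..9 {
        // i-th cell of the same row, the same column and the same block
        let peers = [
            i + 9 * y,
            x + 9 * i,
            (x / 3 * 3 + i % 3) + 9 * (y / 3 * 3 + i / 3),
        ];
        for &peer in peers.iter() {
            // The cell itself is not its own peer.
            if peer != index {
                if let Some(d) = board[peer] {
                    used[d as usize] = true;
                }
            }
        }
    }

    // Digits are produced in ascending order, like iterating a BTreeSet.
    (1..=9u8).filter(|&d| !used[d as usize]).collect()
}
```

For example, on an otherwise empty board with a 2 in the same row, a 3 in the same column and a 4 in the same block as the top-left cell, the candidates for that cell are 1, 5, 6, 7, 8 and 9.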

And now the core of the program.

fn brute_solve_aux(sudoku: &mut Sudoku, index: usize) -> SudokuBruteSolveResult {
    // Base case: we went through the whole board!
    if index >= 9 * 9 {
        if sudoku.is_solved() {
            return Solved;
        }
        return NotYet;
    }

    let (x, y) = index_to_coords(index);

    let current = sudoku.get(x, y);
    match current {
        // If there's a digit already it means it must have been an original hint
        // since we clean up after ourselves when we backtrack.
        // We skip over it, since we can't modify it!
        Some(_digit) => brute_solve_aux(sudoku, index + 1),

        // Opportunity to insert a digit
        None => {
            for this_digit in potential_digits(sudoku, index) {
                sudoku.set(x, y, this_digit);
                match brute_solve_aux(sudoku, index + 1) {
                    // If our recursion comes back with a solution, we return!
                    // Otherwise, we try all the rest of possible digits until
                    // they are exhausted
                    Solved => { return Solved; }
                    _ => ()
                }
            }
            // Clean up after ourselves
            sudoku.clear(x, y);
            // Since at this point we did not find a solution, we return:
            NotYet
        }
    }
}

fn brute_solve(sudoku: &mut Sudoku) {
    brute_solve_aux(sudoku, 0);
}
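Note that `brute_solve` throws away the result of `brute_solve_aux`, so a caller cannot tell an unsolvable puzzle from a solved one without calling `is_solved` again. One possible refinement, sketched here under the assumption that the enum stays as defined above, is to turn the result into a bool. `was_solved` is a hypothetical name.

```rust
// Same enum as in the post, repeated so the snippet compiles on its own.
enum SudokuBruteSolveResult {
    Solved,
    NotYet,
}

// Hypothetical adapter: collapses the two-variant result into a bool.
fn was_solved(result: SudokuBruteSolveResult) -> bool {
    matches!(result, SudokuBruteSolveResult::Solved)
}
```

With such a helper, `brute_solve` could return `was_solved(brute_solve_aux(sudoku, 0))` and have the signature `fn brute_solve(sudoku: &mut Sudoku) -> bool`.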

Well, does it work?

#[test]
fn brute_solve_test() {
    let mut sudoku = Sudoku::new("__8627__9
___5_____
_3__9____
__69__3_2
______95_
1__8_____
____52_63
4___8____
___3__24_");
    brute_solve(&mut sudoku);
    assert!(sudoku.is_solved());
}

Yay! It does!

println!("{}", sudoku);
518627439
269543781
734198526
856974312
347261958
192835674
971452863
423786195
685319247

Conclusions

Implementing the back-tracking sudoku solver in Rust turned out to be easier than I thought it would be.