Solving Sudokus with msat
The glamorous world of SAT and SMT solvers is usually preoccupied with proving theorems, searching for software bugs, verifying hardware, and other similarly serious business.
But today, we're going to solve Sudoku grids. My goal is to showcase mSAT, a parametrized SAT solver written in OCaml by my good friend Guillaume Bury and myself. The solver's code can be found on GitHub, but I'm going to detail the most salient parts here.
Demo
First, a demo: write a sudoku grid in this box, as a string of 81 characters, and press "solve" to load the solver in your browser (using js_of_ocaml, which makes it 3~4 times slower than the native version).
A few difficult grids (from this repo):
- ..............3.85..1.2.......5.7.....4...1...9.......5......73..2.1........4...9
- .......12........3..23..4....18....5.6..7.8.......9.....85.....9...4.5..47...6...
- .2..5.7..4..1....68....3...2....8..3.4..2.5.....6...1...2.9.....9......57.4...9..
- ........3..1..56...9..4..7......9.5.7.......8.5.4.2....8..2..9...35..1..6........
- 12.3....435....1....4........54..2..6...7.........8.9...31..5.......9.7.....6...8
- 1.......2.9.4...5...6...7...5.9.3.......7.......85..4.7.....6...3...9.8...2.....1
- .......39.....1..5..3.5.8....8.9...6.7...2...1..4.......9.8..5..2....6..4..7.....
- 12.3.....4.....3....3.5......42..5......8...9.6...5.7...15..2......9..6......7..8
- ..3..6.8....1..2......7...4..9..8.6..3..4...1.7.2.....3....5.....5...6..98.....5.
- 1.......9..67...2..8....4......75.3...5..2....6.3......9....8..6...4...1..25...6.
- ..9...4...7.3...2.8...6...71..8....6....1..7.....56...3....5..1.4.....9...2...7..
- ....9..5..1.....3...23..7....45...7.8.....2.......64...9..1.....8..6......54....7
- 4...3.......6..8..........1....5..9..8....6...7.2........1.27..5.3....4.9........
- 7.8...3.....2.1...5.........4.....263...8.......1...9..9.6....4....7.5...........
- 3.7.4...........918........4.....7.....16.......25..........38..9....5...2.6.....
- ........8..3...4...9..2..6.....79.......612...6.5.2.7...8...5...1.....2.4.5.....3
- .......1.4.........2...........5.4.7..8...3....1.9....3..4..2...5.1........8.6...
- .......12....35......6...7.7.....3.....4..8..1...........12.....8.....4..5....6..
- 1.......2.9.4...5...6...7...5.3.4.......6........58.4...2...6...3...9.8.7.......1
- .....1.2.3...4.5.....6....7..2.....1.8..9..3.4.....8..5....2....9..3.4....67.....
Overview
The code responsible for the Sudoku solving itself fits in roughly 260 lines of (relatively terse) OCaml. Most of the heavy work is delegated to mSAT, which, being a SAT-solver, is very good at exploring large search spaces and pruning branches — exactly what's useful for combinatorial problems such as Sudoku solving.
mSAT is parametrized (using an OCaml functor) by a theory, i.e. a decision procedure that gives additional meaning to the boolean variables the SAT solver manipulates.
In other words, the SAT solver is responsible for finding an assignment
of boolean variables (true or false) that will satisfy a set of constraints.
But in our case, these variables have an additional meaning: they correspond to
parts of the sudoku. For each cell at (x,y), there are 9 possible values,
so we create 9 boolean variables whose meaning is (x,y) = i for each i.
Then, the SAT solver is going to enumerate boolean assignments of these
variables in an unspecified order; whenever some constraint is violated,
we'll give the SAT solver a conflict (a set of assignments that are incompatible)
to force it to change the assignment.
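To make the "9 boolean variables per cell" idea concrete, here is a minimal sketch of one possible numbering of these variables. The functions var and decode are hypothetical helpers written for this illustration; mSAT itself works with the tuple-based formulas shown later, not with this numbering.

```ocaml
(* Hypothetical numbering: [var x y i] stands for the boolean "(x,y) = i",
   with x and y in 0..8 and i in 1..9. *)
let var x y i =
  assert (0 <= x && x < 9 && 0 <= y && y < 9 && 1 <= i && i <= 9);
  (x * 9 + y) * 9 + (i - 1)

(* Inverse mapping: recover the cell and the value from a variable. *)
let decode v =
  let i = (v mod 9) + 1 in
  let cell = v / 9 in
  (cell / 9, cell mod 9, i)

(* 81 cells times 9 candidate values *)
let n_vars = 9 * 9 * 9
```

With this layout there are 729 variables in total, and a full solution sets exactly 81 of them to true.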
Let's dive into the code a bit more now.
Cells
First, in the code, we have some boilerplate for cells (private aliases
for integers between 0 (empty) and 9). The style is a bit verbose but
it's also very robust: it forces creation of cells to go through Cell.make,
which checks their validity.
module Cell : sig
  type t = private int
  val equal : t -> t -> bool
  val neq : t -> t -> bool
  val hash : t -> int
  val empty : t
  val is_empty : t -> bool
  val is_full : t -> bool
  val make : int -> t
  val pp : t Fmt.printer
end = struct
  type t = int
  let empty = 0
  let[@inline] make i = assert (i >= 0 && i <= 9); i
  let[@inline] is_empty x = x = 0
  let[@inline] is_full x = x > 0
  let hash = CCHash.int
  let[@inline] equal (a:t) b = a=b
  let[@inline] neq (a:t) b = a<>b
  let pp out i = if i=0 then Fmt.char out '.' else Fmt.int out i
end
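To illustrate what the private alias buys us, here is a small standalone sketch. It uses a cut-down Cell module written for this example only, with invalid_arg instead of assert: reading a cell as an int is a free coercion, but building one has to go through make.

```ocaml
(* Cut-down version of the module above, for illustration only. *)
module Cell : sig
  type t = private int
  val make : int -> t
end = struct
  type t = int
  let make i = if i < 0 || i > 9 then invalid_arg "Cell.make"; i
end

let c = Cell.make 5

(* Reading is free: a private int can be coerced back to int. *)
let n : int = (c :> int) + 1

(* Writing [let bad : Cell.t = 12] would be rejected by the type checker:
   outside the module, the only way to obtain a [Cell.t] is [Cell.make]. *)
```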
So far, nothing particularly interesting. Let's look at grids.
Grids
Grids are a bit more interesting, for several reasons.
First, I wrote this sudoku solver as a means to test mSAT's API without writing thousands of lines of code. It means that the Sudoku solver is going to check the solution, and fail if the API had a bug that led to an invalid solution.
Second, constraints over lines, columns, and squares are very redundant.
To cut through the repetitions I use sequence,
a very fast iterator library for OCaml
(see the slides of an old talk at OUPS 2014).
This makes the code a lot more compact, but also a bit harder to understand.
In a nutshell, a value of type 'a Sequence.t is a series of values of type 'a,
represented by a function ('a -> unit) -> unit (an iter function).
Each set of cells that must be distinct (each column, row, and 3×3 square)
is represented as a sequence of cells. Then we just have to assert that
these sequences don't contain duplicates (see all_distinct).
The function matches checks that a grid with some undefined cells
matches a fully defined grid (i.e. it checks that the solution returned by the SAT solver
is actually a solution of the initial grid, on top of being valid).
Finally we have a pretty-printer and a parser for the 81-chars representation of a grid.
module Grid : sig
  type t
  val get : t -> int -> int -> Cell.t
  val set : t -> int -> int -> Cell.t -> t
  (** A set of related cells *)
  type set = (int*int*Cell.t) Sequence.t
  val rows : t -> set Sequence.t
  val cols : t -> set Sequence.t
  val squares : t -> set Sequence.t
  val all_cells : t -> (int*int*Cell.t) Sequence.t
  val parse : string -> t
  val is_full : t -> bool
  val is_valid : t -> bool
  val matches : pat:t -> t -> bool
  val pp : t Fmt.printer
end = struct
  type t = Cell.t array
  let[@inline] get (s:t) i j = s.(i*9 + j)
  (* functional update: the original grid is left untouched *)
  let[@inline] set (s:t) i j n =
    let s' = Array.copy s in
    s'.(i*9 + j) <- n;
    s'
  (** A set of related cells, with their positions *)
  type set = (int*int*Cell.t) Sequence.t
  open Sequence.Infix
  let all_cells (g:t) =
    0 -- 8 >>= fun i ->
    0 -- 8 >|= fun j -> (i,j,get g i j)
  let rows (g:t) =
    0 -- 8 >|= fun i ->
    ( 0 -- 8 >|= fun j -> (i,j,get g i j))
  let cols g =
    0 -- 8 >|= fun j ->
    ( 0 -- 8 >|= fun i -> (i,j,get g i j))
  let squares g =
    0 -- 2 >>= fun sq_i ->
    0 -- 2 >|= fun sq_j ->
    ( 0 -- 2 >>= fun off_i ->
      0 -- 2 >|= fun off_j ->
      let i = 3*sq_i + off_i in
      let j = 3*sq_j + off_j in
      (i,j,get g i j))
  let is_full g = Array.for_all Cell.is_full g
  (* does the grid satisfy the unicity constraints? *)
  let is_valid g =
    let all_distinct (s:set) =
      (s >|= fun (_,_,c) -> c)
      |> Sequence.diagonal
      |> Sequence.for_all (fun (c1,c2) -> Cell.neq c1 c2)
    in
    Sequence.for_all all_distinct @@ rows g &&
    Sequence.for_all all_distinct @@ cols g &&
    Sequence.for_all all_distinct @@ squares g
  (* does [g2] correspond to [g1] wherever [g1] is defined? *)
  let matches ~pat:g1 g2 : bool =
    all_cells g1
    |> Sequence.filter (fun (_,_,c) -> Cell.is_full c)
    |> Sequence.for_all (fun (x,y,c) -> Cell.equal c @@ get g2 x y)
  let pp out g =
    Fmt.fprintf out "@[<v>";
    Array.iteri
      (fun i n ->
         Cell.pp out n;
         if i mod 9 = 8 then Fmt.fprintf out "@,")
      g;
    Fmt.fprintf out "@]"
  (* parse a grid represented by 81 chars
     ([errorf] is an error-reporting helper that is not shown in this excerpt) *)
  let parse (s:string) : t =
    if String.length s < 81 then (
      errorf "line is too short, expected 81 chars, not %d" (String.length s);
    );
    let a = Array.make 81 Cell.empty in
    for i = 0 to 80 do
      let c = String.get s i in
      let n = if c = '.' then 0 else Char.code c - Char.code '0' in
      if n < 0 || n > 9 then errorf "invalid char %c" c;
      a.(i) <- Cell.make n
    done;
    a
end
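Since the iterator type carries most of the weight in Grid, here is a rough, dependency-free sketch of how such iterators can be written. The names (iter, range, map, flat_map, diagonal, for_all) mirror the ones used above but are reimplemented here for illustration; the real library is more complete and its implementation may differ.

```ocaml
(* A sequence is just an "iter" function: it feeds each element to [k]. *)
type 'a iter = ('a -> unit) -> unit

let range i j : int iter = fun k -> for x = i to j do k x done
let map f (s : 'a iter) : 'b iter = fun k -> s (fun x -> k (f x))
let flat_map f (s : 'a iter) : 'b iter = fun k -> s (fun x -> f x k)

let to_list (s : 'a iter) : 'a list =
  let l = ref [] in
  s (fun x -> l := x :: !l);
  List.rev !l

(* Early exit needs an exception, since [k] cannot stop the iteration. *)
let for_all p (s : 'a iter) : bool =
  let exception Stop in
  try s (fun x -> if not (p x) then raise Stop); true
  with Stop -> false

(* All pairs of elements at distinct positions (materializes the sequence). *)
let diagonal (s : 'a iter) : ('a * 'a) iter = fun k ->
  let rec go = function
    | [] -> ()
    | x :: tl -> List.iter (fun y -> k (x, y)) tl; go tl
  in
  go (to_list s)

(* Same shape as [all_distinct] in [Grid.is_valid], over plain ints. *)
let all_distinct (s : int iter) : bool =
  for_all (fun (a, b) -> a <> b) (diagonal s)
```

Enumerating all cell positions is then a flat_map over two ranges, which is what the >>= and >|= operators express in Grid.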
Backtracking
Now for something completely different:
module B_ref = Msat_backtrack.Ref
This is a handy alias to a sub-library of mSAT, which provides a backtrackable reference. This is going to be helpful to maintain the representation of the current grid as we follow the SAT solver's exploration of the search space.
For reference, the API of this "backtrackable reference" is:
module B_ref : sig
  type 'a t
  val create : ?copy:('a -> 'a) -> 'a -> 'a t
  (** Create a backtrackable reference holding the given value initially.
      @param copy if provided, will be used to copy the value when [push_level]
      is called. *)
  val set : 'a t -> 'a -> unit
  (** Set the reference's current content *)
  val get : 'a t -> 'a
  (** Get the reference's current content *)
  val update : 'a t -> ('a -> 'a) -> unit
  (** Update the reference's current content *)
  val push_level : _ t -> unit
  (** Push a backtracking level, copying the current value on top of some
      stack. The [copy] function will be used if it was provided in {!create}. *)
  val n_levels : _ t -> int
  (** Number of saved values *)
  val pop_levels : _ t -> int -> unit
  (** Pop [n] levels, restoring to the value the reference was storing [n] calls
      to [push_level] earlier.
      @raise Invalid_argument if [n] is bigger than [n_levels]. *)
end
Basically, you can update the reference, but you can also push_level (which creates
a backtracking point) and pop_levels (which restores the reference to the state it had
n backtracking points earlier).
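For intuition, a backtrackable reference can be sketched as a current value plus a stack of saved values. This is a simplified model written for this post (it ignores the optional copy argument and update), not the actual code of Msat_backtrack.Ref.

```ocaml
(* Simplified model: the current value and the stack of saved values. *)
type 'a bref = {
  mutable cur : 'a;
  mutable stack : 'a list;
}

let create x = { cur = x; stack = [] }
let get r = r.cur
let set r x = r.cur <- x

(* Save the current value: this is a backtracking point. *)
let push_level r = r.stack <- r.cur :: r.stack
let n_levels r = List.length r.stack

(* Restore the value saved [n] backtracking points ago. *)
let pop_levels r n =
  if n > n_levels r then invalid_arg "pop_levels";
  for _i = 1 to n do
    match r.stack with
    | x :: tl -> r.cur <- x; r.stack <- tl
    | [] -> assert false
  done
```

Because Grid.set returns a fresh copy of the grid, saving a grid in such a reference is cheap: the old array is never mutated afterwards.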
The Solver
And now, the main dish: the solver itself! I'm going to cut it into several parts to explain better.
module Solver : sig
  type t
  val create : Grid.t -> t
  val solve : t -> Grid.t option
end = struct
…
Well, the API is quite simple: a type Solver.t, a function to create it
with the initial (partial) grid, and a solve function which returns
a solution if there's one.
…
  open Msat.Solver_intf
  (* formulas *)
  module F = struct
    type t = bool*int*int*Cell.t
    let equal (sign1,x1,y1,c1)(sign2,x2,y2,c2) =
      sign1=sign2 && x1=x2 && y1=y2 && Cell.equal c1 c2
    let hash (sign,x,y,c) = CCHash.(combine4 (bool sign)(int x)(int y)(Cell.hash c))
    let pp out (sign,x,y,c) =
      Fmt.fprintf out "[@[(%d,%d) %s %a@]]" x y (if sign then "=" else "!=") Cell.pp c
    (* negation: just flip the sign *)
    let neg (sign,x,y,c) = (not sign,x,y,c)
    let norm ((sign,_,_,_) as f) =
      if sign then f, Same_sign else neg f, Negated
    let make sign x y (c:Cell.t) : t = (sign,x,y,c)
  end
…
Ah yes, this defines a notion of formulas (the boolean variables mentioned
earlier).
A formula, an atom of truth that the SAT solver is going to manipulate
and to assign to true or false, is here a tuple (bool*int*int*Cell.t).
A value (sign, x, y, c) is the formula (x,y) = c (or (x,y) != c if sign=false).
We add a sign to it because it makes negation easy.
Eventually, once search has terminated successfully, we will have a collection
of true formulas (and one of false formulas, which are of little interest here).
These true formulas will describe the complete state of the grid.
Note that mSAT requires the atomic formulas to be comparable and hashable, cheaply, if possible. It maintains an internal hash table to map these formulas to its own internal representation of boolean variables.
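To see why equality and hashing matter, here is a hedged sketch of the kind of table a SAT solver could keep to map formulas to integer literals. The names table, next and lit_of_formula are made up for this example, cell values are plain ints, and mSAT's internal representation is not shown in this post and may well differ.

```ocaml
(* Formulas as in the article, with plain ints for cell values
   (an assumption made to keep the sketch self-contained). *)
module F = struct
  type t = bool * int * int * int
  let equal (a : t) (b : t) = a = b
  let hash (f : t) = Hashtbl.hash f
  let neg (sign, x, y, c) = (not sign, x, y, c)
end

module Tbl = Hashtbl.Make (F)

(* Hypothetical interning table: each positive formula gets a fresh id;
   a negated formula maps to the opposite integer. *)
let table : int Tbl.t = Tbl.create 64
let next = ref 1

let lit_of_formula ((sign, _, _, _) as f : F.t) : int =
  (* normalize to the positive form, like [F.norm] does *)
  let pos = if sign then f else F.neg f in
  let id =
    match Tbl.find_opt table pos with
    | Some id -> id
    | None ->
      let id = !next in
      incr next;
      Tbl.add table pos id;
      id
  in
  if sign then id else -id
```

Every lookup hashes and compares a formula, which is why these two operations should be cheap.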
…
  module Theory = struct
    type proof = unit
    module Formula = F
    type t = {
      grid: Grid.t B_ref.t;
    }
    let create g : t = {grid=B_ref.create g}
    let[@inline] grid self : Grid.t = B_ref.get self.grid
    let[@inline] set_grid self g : unit = B_ref.set self.grid g
    let push_level self = B_ref.push_level self.grid
    let pop_levels self n = B_ref.pop_levels self.grid n
    let pp_c_ = Fmt.(list ~sep:(return "@ ∨ ")) F.pp
    let[@inline] logs_conflict kind c : unit =
      Log.debugf 4 (fun k->k "(@[conflict.%s@ %a@])" kind pp_c_ c)
    (* check that all cells are full *)
    let check_full_ (self:t) acts : unit =
      Grid.all_cells (grid self)
        (fun (x,y,c) ->
           if Cell.is_empty c then (
             let c =
               CCList.init 9
                 (fun c -> F.make true x y (Cell.make (c+1)))
             in
             Log.debugf 4 (fun k->k "(@[add-clause@ %a@])" pp_c_ c);
             acts.acts_add_clause ~keep:true c ();
           ))
    (* check constraints *)
    let check_ (self:t) acts : unit =
      Log.debugf 4 (fun k->k "(@[sudoku.check@ @[:g %a@]@])" Grid.pp (B_ref.get self.grid));
      let[@inline] all_diff kind f =
        let pairs =
          f (grid self)
          |> Sequence.flat_map
            (fun set ->
               set
               |> Sequence.filter (fun (_,_,c) -> Cell.is_full c)
               |> Sequence.diagonal)
        in
        pairs
          (fun ((x1,y1,c1),(x2,y2,c2)) ->
             if Cell.equal c1 c2 then (
               assert (x1<>x2 || y1<>y2);
               let c = [F.make false x1 y1 c1; F.make false x2 y2 c2] in
               logs_conflict ("all-diff." ^ kind) c;
               acts.acts_raise_conflict c ()
             ))
      in
      all_diff "rows" Grid.rows;
      all_diff "cols" Grid.cols;
      all_diff "squares" Grid.squares;
      ()
    let trail_ (acts:_ Msat.acts) =
      acts.acts_iter_assumptions
      |> Sequence.map
        (function
          | Assign _ -> assert false
          | Lit f -> f)
    (* update current grid with the given slice *)
    let add_slice (self:t) acts : unit =
      trail_ acts
        (function
          | false,_,_,_ -> ()
          | true,x,y,c ->
            assert (Cell.is_full c);
            let grid = grid self in
            let c' = Grid.get grid x y in
            if Cell.is_empty c' then (
              set_grid self (Grid.set grid x y c);
            ) else if Cell.neq c c' then (
              (* conflict: at most one value *)
              let c = [F.make false x y c; F.make false x y c'] in
              logs_conflict "at-most-one" c;
              acts.acts_raise_conflict c ()
            )
        )
    let partial_check (self:t) acts : unit =
      Log.debugf 4
        (fun k->k "(@[sudoku.partial-check@ :trail [@[%a@]]@])" (Fmt.seq F.pp) (trail_ acts));
      add_slice self acts;
      check_ self acts
    let final_check (self:t) acts : unit =
      Log.debugf 4 (fun k->k "(@[sudoku.final-check@])");
      check_full_ self acts;
      check_ self acts
  end
…
This is the theory, the core of the reasoning engine. It interacts with the SAT solver's partial models (candidate assignments that might be solutions… or not).
There are two entry points here:
- partial_check is called with assignments of a subset of the formulas, during the search. It means it's called very often, so it had better be fast. Its job is to reject assignments that are obviously wrong.
  The call to add_slice updates the current model of the grid (which lives in a backtrackable reference, as sometimes the SAT solver undoes its previous choices) with the new decisions the SAT solver made. This might fail if a cell is assigned to two distinct values (look for "at-most-one" in the code); a clause (x,y) != i OR (x,y) != j is added, and the model is rejected.
  Otherwise, check_ is called to verify that the row/column/square constraints are respected. If some constraint is not respected (say (1,3)=7 and (1,6)=7, which means row 1 contains 7 twice), a clause is added to reject it: (1,3) != 7 OR (1,6) != 7; the model is rejected, the solver backtracks, and search resumes.
- final_check is called when mSAT has a full model (all the formulas are true or false). If our Theory accepts this model, the search ends, and we can decode the bunch of true formulas into an actual grid. Otherwise, it means there is some validity issue, and we must raise a conflict, i.e. reject the assignment. This can be more costly but it has to check the model fully, as it's potentially the last chance to reject an invalid model.
  In addition to calling check_ like in partial_check, final_check also verifies that all cells are assigned (check_full_). If it's not the case, there's a cell (x,y) which is unassigned, and a clause (x,y)=1 OR (x,y)=2 OR … OR (x,y)=9 is created and added to the SAT solver. Search then resumes.
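The two kinds of clauses described above can be illustrated without mSAT at all. The following sketch uses plain lists and ints; find_conflict and at_least_one are names invented for this example, and unlike the real theory (which raises the conflict through the acts record) it simply returns the clause.

```ocaml
(* A literal [(sign, x, y, c)] reads "(x,y) = c" when [sign] is true
   and "(x,y) != c" otherwise. Cell values are plain ints here. *)
type lit = bool * int * int * int

(* All unordered pairs of a list. *)
let rec pairs = function
  | [] -> []
  | x :: tl -> List.map (fun y -> (x, y)) tl @ pairs tl

(* Given the assigned cells of one row, column or square, return the
   first conflict clause, if any: two cells holding the same value. *)
let find_conflict (cells : (int * int * int) list) : lit list option =
  pairs cells
  |> List.find_opt (fun ((_, _, c1), (_, _, c2)) -> c1 = c2)
  |> Option.map (fun ((x1, y1, c1), (x2, y2, c2)) ->
         [ (false, x1, y1, c1); (false, x2, y2, c2) ])

(* Clause added by the final check for an unassigned cell:
   (x,y)=1 OR (x,y)=2 OR ... OR (x,y)=9. *)
let at_least_one x y : lit list =
  List.init 9 (fun i -> (true, x, y, i + 1))
```

For instance, with (1,3)=7 and (1,6)=7 among the assigned cells, find_conflict yields the clause (1,3) != 7 OR (1,6) != 7.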
These two functions are the heart and soul of the theory. In a proper SMT solver (such as Z3, CVC4, Yices2, etc.), the hundreds of thousands of lines of C or C++ are mostly dedicated to this part, and the SAT solver is much smaller.
Finally, the wrapper code:
…
  module S = Msat.Make_cdcl_t(Theory)
  type t = {
    grid0: Grid.t;
    solver: S.t;
  }
  let solve (self:t) : _ option =
    let assumptions =
      Grid.all_cells self.grid0
      |> Sequence.filter (fun (_,_,c) -> Cell.is_full c)
      |> Sequence.map (fun (x,y,c) -> F.make true x y c)
      |> Sequence.map (S.make_atom self.solver)
      |> Sequence.to_rev_list
    in
    Log.debugf 2
      (fun k->k "(@[sudoku.solve@ :assumptions %a@])" (Fmt.Dump.list S.Atom.pp) assumptions);
    let r =
      match S.solve self.solver ~assumptions with
      | S.Sat _ -> Some (Theory.grid (S.theory self.solver))
      | S.Unsat _ -> None
    in
    r
  let create g : t =
    { solver=S.create ~store_proof:false (Theory.create g); grid0=g }
end
We instantiate mSAT's functor over the Theory using Msat.Make_cdcl_t,
and wrap the S.solve function by asserting formulas that correspond to
the initial grid (i.e. if the initial grid contains 5 at position (1,2),
we assume (1,2)=5 to be true from the beginning).
If the solver returns SAT, we return the current Grid, since it should be full
and valid.
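As a standalone illustration of how the initial grid becomes a list of assumptions, here is a small sketch that goes directly from the 81-character representation to the list of formulas assumed true. The function assumptions_of_string is hypothetical (the real code goes through Grid.all_cells and S.make_atom), and it is slightly stricter than Grid.parse in that it requires exactly 81 characters.

```ocaml
(* Turn an 81-character grid into the list of formulas assumed true.
   Each formula [(true, x, y, c)] reads "(x,y) = c". *)
let assumptions_of_string (s : string) : (bool * int * int * int) list =
  if String.length s <> 81 then invalid_arg "expected 81 chars";
  List.init 81 (fun idx -> idx)
  |> List.filter_map (fun idx ->
         match s.[idx] with
         | '.' | '0' -> None  (* empty cell: nothing to assume *)
         | '1' .. '9' as ch ->
           Some (true, idx / 9, idx mod 9, Char.code ch - Char.code '0')
         | ch -> invalid_arg (Printf.sprintf "invalid char %c" ch))
```

A grid whose only given is a 5 at position (1,2) produces the single assumption (true, 1, 2, 5).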
Conclusion
This solver is not the most efficient. In practice, for Sudoku, it's faster and simpler to encode the whole problem into SAT from the start ("bit-blasting") and call minisat or some other state-of-the-art SAT solver.
But this way is more fun, and more flexible. I hope it demonstrates that writing
small CDCL(T) solvers (where T is whatever you want it to be, cough)
is not that hard. It also shows the abstraction power of OCaml's functors.
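For comparison, the "encode everything upfront" approach mentioned above can be sketched in a few lines. This is one common minimal encoding, written for illustration and not taken from mSAT or minisat: literals are DIMACS-style integers, and var, units and clauses are names chosen for this example.

```ocaml
(* Variable for "(x,y) = i", numbered from 1 to 729 (x, y in 0..8, i in 1..9). *)
let var x y i = (x * 9 + y) * 9 + i
let neg v = -v

let range = List.init 9 (fun i -> i)
let digits = List.map (fun i -> i + 1) range

(* All unordered pairs of a list. *)
let rec pairs = function
  | [] -> []
  | x :: tl -> List.map (fun y -> (x, y)) tl @ pairs tl

(* The 27 units: 9 rows, 9 columns, 9 squares, each a list of 9 positions. *)
let units : (int * int) list list =
  let rows = List.map (fun x -> List.map (fun y -> (x, y)) range) range in
  let cols = List.map (fun y -> List.map (fun x -> (x, y)) range) range in
  let squares =
    List.map
      (fun sq ->
         List.map
           (fun off -> (3 * (sq / 3) + off / 3, 3 * (sq mod 3) + off mod 3))
           range)
      range
  in
  rows @ cols @ squares

let clauses : int list list =
  let cells = List.concat_map (fun x -> List.map (fun y -> (x, y)) range) range in
  (* each cell holds at least one value *)
  let at_least = List.map (fun (x, y) -> List.map (var x y) digits) cells in
  (* each cell holds at most one value *)
  let at_most =
    List.concat_map
      (fun (x, y) ->
         List.map (fun (i, j) -> [ neg (var x y i); neg (var x y j) ]) (pairs digits))
      cells
  in
  (* no value appears twice in a unit *)
  let distinct =
    List.concat_map
      (fun u ->
         List.concat_map
           (fun i ->
              List.map
                (fun ((x1, y1), (x2, y2)) -> [ neg (var x1 y1 i); neg (var x2 y2 i) ])
                (pairs u))
           digits)
      units
  in
  at_least @ at_most @ distinct
```

This produces 81 + 2916 + 8748 = 11745 clauses over 729 variables; the givens of a particular grid are then added as unit clauses. The theory-based solver in this post generates the same kind of clauses, but lazily, only when a conflict actually shows up.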