Friday, April 29, 2011

Logic programming in Scala, part 2: backtracking

In the previous post we saw how to write computations in a logic monad, where a “value” is a choice among alternatives, and operating on a value means operating on all the alternatives.

Our first implementation of the logic monad represents a choice among alternatives as a list, and operating on a value means running the operation for each alternative immediately (to produce a new list of alternatives). If we imagine alternatives as leaves of a tree (with | indicating branching), the first implementation explores the tree breadth-first.

This is OK for some problems, but we run into trouble when the number of alternatives is large or infinite. For example, a choice among the natural numbers:

scala> import LogicList._ 
import LogicList._ 
 
scala> val nat: T[Int] = unit(0) | nat.map(_ + 1) 
java.lang.NullPointerException 
        ... 

This goes wrong because even though the right-hand argument to | is by-name, we immediately try to use it, and fail because nat is not yet defined.

scala> def nat: T[Int] = unit(0) | nat.map(_ + 1) 
nat: LogicList.T[Int] 
scala> run(nat, 10) 
java.lang.StackOverflowError 
        ... 

With def we can successfully define nat, because the right-hand side isn’t evaluated until nat is used in the call to run, but we overflow the stack trying to compute all the natural numbers.

Let’s repair this with a fancier implementation of the logic monad, translated from Kiselyov et al.’s Backtracking, Interleaving, and Terminating Monad Transformers. This implementation will explore the tree depth-first.

Success and failure continuations

The idea is to represent a choice of alternatives by a function, which takes as arguments two functions: a success continuation and a failure continuation. The success continuation is just a function indicating what to do next with each alternative; the failure continuation is what to do next when there are no more alternatives.

For success, what we do next is either return the alternative (when we have reached a leaf of the tree), or perform some operation on it (possibly forming new branches rooted at the alternative). For failure, what we do next is back up to the last branch point and succeed with the next alternative. If there are no more alternatives at the previous branch point we back up again, and so on until we can succeed or finally run out of alternatives. In other words, we do depth-first search on the tree, except that the tree isn’t a materialized data structure—it’s created on the fly.

(In the jargon of logic programming, a branch point is called a “choice point”, and going back to an earlier choice point is called “backtracking”.)

object LogicSFK extends Logic { 
  type FK[R] = () => R 
  type SK[A,R] = (A, FK[R]) => R 
 
  trait T[A] { def apply[R](sk: SK[A,R], fk: FK[R]): R } 

The continuations can return a result of some arbitrary type R. This means that the function representing a choice has a “rank-2” polymorphic type—it takes functions which are themselves polymorphic—which is not directly representable in Scala. But we can encode it by making the representation function a method on a trait.

The success continuation takes a value of the underlying type (i.e. an alternative), and also a failure continuation, to call in case this branch of the tree eventually fails (by calling fail, or filter when no alternative satisfies the predicate). The failure continuation is also called to succeed with the next alternative after returning a leaf (see split).

  def fail[A] = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = fk() 
    } 
 
  def unit[A](a: A) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = sk(a, fk) 
    } 

To fail, just call the failure continuation. To succeed with one alternative, call the success continuation with the single alternative and the passed-in failure continuation—there are no more alternatives to try, so if this branch fails the unit fails.

  def or[A](t1: T[A], t2: => T[A]) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = 
        t1(sk, { () => t2(sk, fk) }) 
    } 

Or creates a choice point. We want to explore the alternatives in both t1 and t2, so we pass the success continuation to t1 (which calls it on each alternative); when t1 is exhausted we pass the success continuation to t2; finally we fail with the caller’s failure continuation—that is, we backtrack.
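To see the two continuations at work, here is a minimal self-contained sketch that repeats fail, unit, and or from above and adds two observers. The names SfkSketch, toList, and first are my own and not part of the Logic trait; toList is only safe on finite choices, since it recurses once per alternative.

```scala
object SfkSketch {
  type FK[R] = () => R
  type SK[A, R] = (A, FK[R]) => R

  trait T[A] { def apply[R](sk: SK[A, R], fk: FK[R]): R }

  def fail[A]: T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R = fk()
  }

  def unit[A](a: A): T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R = sk(a, fk)
  }

  def or[A](t1: T[A], t2: => T[A]): T[A] = new T[A] {
    // t2 is only evaluated if t1's failure continuation is called
    def apply[R](sk: SK[A, R], fk: FK[R]): R = t1(sk, () => t2(sk, fk))
  }

  // Hypothetical helper: on success, cons the alternative onto whatever
  // backtracking produces; on final failure, end the list.
  def toList[A](t: T[A]): List[A] =
    t.apply[List[A]]((a, fk) => a :: fk(), () => Nil)

  // Hypothetical helper: stop at the first success and never backtrack.
  def first[A](t: T[A]): Option[A] =
    t.apply[Option[A]]((a, _) => Some(a), () => None)
}
```

Because first ignores the failure continuation it is given, the right-hand argument of or is never evaluated when the left-hand side succeeds.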

  def bind[A,B](t: T[A], f: A => T[B]) = 
    new T[B] { 
      def apply[R](sk: SK[B,R], fk: FK[R]) = 
        t(({ (a, fk) => f(a)(sk, fk) }: SK[A,R]), fk) 
    } 
 
  def apply[A,B](t: T[A], f: A => B) = 
    new T[B] { 
      def apply[R](sk: SK[B,R], fk: FK[R]) = 
        t(({ (a, fk) => sk(f(a), fk) }: SK[A,R]), fk) 
    } 

For bind we extend each branch by calling f on the current leaf. To succeed we call f on the alternative a. Now f(a) returns a choice of alternatives, so we pass it the original success continuation (which says what to do next with alternatives resulting from the bind), and the failure continuation in force at the point a was generated (which succeeds with the next available alternative from f(a)).

For apply things are simpler, since f(a) returns a single value rather than a choice of alternatives: we succeed immediately with the returned value.

  def filter[A](t: T[A], p: A => Boolean) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = { 
        val sk2: SK[A,R] = 
          { (a, fk) => if (p(a)) sk(a, fk) else fk() } 
        t(sk2, fk) 
      } 
    } 

To filter a choice of alternatives, each time we succeed with a value we see if it satisfies the predicate p; if it does, we succeed with that value (extending the branch), otherwise we fail (pruning the branch).

  def split[A](t: T[A]) = { 
    def unsplit(fk: FK[Option[(A,T[A])]]): T[A] = 
      fk() match { 
        case None => fail 
        case Some((a, t)) => or(unit(a), t) 
      } 
    def sk : SK[A,Option[(A,T[A])]] = 
      { (a, fk) => Some((a, bind(unit(fk), unsplit))) } 
    t(sk, { () => None }) 
  } 
} 

The point of split is to pull a single alternative from a choice, returning along with it a choice of the remaining alternatives. In the list implementation we just returned the head and tail of the list. In this implementation, the alternatives are computed on demand; we want to be careful to do only as much computation as needed to pull the first alternative.

The failure continuation we pass to t just returns None when there are no more alternatives. The success continuation sk returns the first alternative and a choice of the remaining alternatives (wrapped in Some).

The tricky part is the choice of remaining alternatives. We’re given the failure continuation fk; calling it calls sk on the next alternative, which ultimately returns Some(a, t) where a is the next alternative, or None if there are no more alternatives. We repackage this Option as a choice of alternatives with unsplit. So that we don’t call fk too soon, we call unsplit via bind, which defers the call until the resulting choice of alternatives is actually used.

Now we can write infinite choices:

scala> import LogicSFK._ 
import LogicSFK._ 
 
scala> val nat: T[Int] = unit(0) | nat.map(_ + 1) 
nat: LogicSFK.T[Int] = LogicSFK$$anon$3@27aea0c1 
 
scala> run(nat, 10) 
res1: List[Int] = List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) 

Well, this is a pretty complicated way to generate the natural numbers up to 10…

While nat looks like a lazy stream (as you might write in Haskell), no results are memoized (as they are in Haskell). To compute each successive number all the previous ones must be recomputed, and the running time of run(nat, N) is O(N²).
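For contrast, here is a sketch of the memoizing version, assuming Scala 2.13 or later where LazyList is in the standard library (older releases offered Stream for a similar purpose). The object name NatStream is made up for this example.

```scala
object NatStream {
  // Each cell of a LazyList is evaluated at most once and then cached,
  // so later elements reuse the earlier ones instead of recomputing them.
  lazy val nat: LazyList[Int] = 0 #:: nat.map(_ + 1)
}
```

Taking the first ten elements with NatStream.nat.take(10).toList forces only those ten cells.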

Defunctionalization

The code above is a fairly direct translation of the Haskell code from the paper. But its use of continuation-passing style doesn’t map well to Scala, because Scala doesn’t implement tail-call elimination (because the JVM doesn’t). Every call to a success or failure continuation adds a frame to the stack, even though all we ever do with the result is return it (i.e. the call is in tail-position), so the stack frame could be eliminated.

Surprisingly, we run out of memory before we run out of stack:

scala> run(nat, 2000) 
java.lang.OutOfMemoryError: Java heap space 
 ... 

A little heap profiling shows that we’re using quadratic space as well as quadratic time. It turns out that the implementation of Logic.run (from the previous post) has a space leak. The call to run is not tail-recursive, so the stack frame hangs around, and although t is dead after split(t), there’s still a reference to it on the stack.

We can rewrite run with an accumulator to be tail-recursive:

  def run[A](t: T[A], n: Int): List[A] = { 
    def runAcc(t: T[A], n: Int, acc: List[A]): List[A] = 
      if (n <= 0) acc.reverse else 
        split(t) match { 
          case None => acc.reverse 
          case Some((a, t)) => runAcc(t, n - 1, a :: acc) 
        } 
    runAcc(t, n, Nil) 
  } 

Now scalac compiles runAcc as a loop, so there are no stack frames holding on to dead values of t, and we get the expected:

scala> run(nat, 9000) 
java.lang.StackOverflowError 
 ... 

To address the stack overflow we turn to defunctionalization. The idea (from John Reynolds’s classic paper Definitional Interpreters for Higher-Order Programming Languages) is to replace functions and their applications with data constructors (we’ll use case classes) and an apply function, which matches the data constructor and does whatever the corresponding function body does. If a function captures variables, the data constructor must capture the same variables.
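A toy instance of the idea, separate from the logic monad: two closures that capture an Int, and their defunctionalized counterparts. All the names here (DefuncSketch, adder, scaler, Fn, Adder, Scaler, applyFn) are invented for illustration.

```scala
object DefuncSketch {
  // Higher-order original: each closure captures n.
  def adder(n: Int): Int => Int = x => x + n
  def scaler(n: Int): Int => Int = x => x * n

  // Defunctionalized: one case class per closure, holding the captured n.
  sealed trait Fn
  case class Adder(n: Int) extends Fn
  case class Scaler(n: Int) extends Fn

  // The apply function matches the constructor and runs the
  // corresponding function body.
  def applyFn(f: Fn, x: Int): Int = f match {
    case Adder(n)  => x + n
    case Scaler(n) => x * n
  }
}
```

So applyFn(Adder(2), 3) plays the role of adder(2)(3), with the code in one place and the captured data in another.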

After defunctionalization we’re left with three mutually recursive apply functions (one for each of T, FK, and SK) where each recursive call is in tail position. In theory the compiler could transform these into code that takes only constant stack space (since they are local functions private to split). But in fact it will do so only for single recursive functions, so we will need to do this transformation by hand.

There is one hitch: the original code is not completely tail-recursive, because of unsplit, which calls a failure continuation then matches on the result. To fix this we need to add yet another continuation, which represents what to do after returning a result from a success or failure continuation.

object LogicSFKDefunc extends Logic { 
  type O[A] = Option[(A,T[A])] 
 
  sealed trait T[A] 
  case class Fail[A]() extends T[A] 
  case class Unit[A](a: A) extends T[A] 
  case class Or[A](t1: T[A], t2: () => T[A]) extends T[A] 
  case class Bind[A,B](t: T[A], f: A => T[B]) extends T[B] 
  case class Apply[A,B](t: T[A], f: A => B) extends T[B] 
  case class Filter[A](t: T[A], p: A => Boolean) extends T[A] 
  case class Unsplit[A](fk: FK[O[A]]) extends T[A] 
 
  def fail[A] = Fail() 
  def unit[A](a: A) = Unit(a) 
  def or[A](t1: T[A], t2: => T[A]) = Or(t1, { () => t2 }) 
  def bind[A,B](t: T[A], f: A => T[B]) = Bind(t, f) 
  def apply[A,B](t: T[A], f: A => B) = Apply(t, f) 
  def filter[A](t: T[A], p: A => Boolean) = Filter(t, p) 

A choice of alternatives T[A] is now represented symbolically by case classes, and the functions which operate on choices just return the corresponding case. The cases capture the same variables that were captured in the original functions.

We have an additional case Unsplit which represents the bind(unit(fk), unsplit) combination from split. And we use O[A] as a convenient abbreviation.

  sealed trait FK[R] 
  case class FKOr[A,R](t: () => T[A], sk: SK[A,R], fk: FK[R]) 
    extends FK[R] 
  case class FKSplit[R](r: R) extends FK[R] 
 
  sealed trait SK[A,R] 
  case class SKBind[A,B,R](f: A => T[B], sk: SK[B,R]) 
    extends SK[A,R] 
  case class SKApply[A,B,R](f: A => B, sk: SK[B,R]) 
    extends SK[A,R] 
  case class SKFilter[A,R](p: A => Boolean, sk: SK[A,R]) 
    extends SK[A,R] 
  case class SKSplit[A,R](r: (A, FK[R]) => R) extends SK[A,R] 
 
  sealed trait K[R,R2] 
  case class KReturn[R]() extends K[R,R] 
  case class KUnsplit[A,R,R2](sk: SK[A,R], fk:FK[R], k: K[R,R2]) 
    extends K[O[A],R2] 

Each case for FK (respectively SK) corresponds to a failure (respectively success) continuation function in the original code; it’s easy to match them up.

The K cases are for the new return continuation. They are defunctionalized from functions R => R2; we can either return a value directly, or match on whether it is Some or None as in unsplit. (If K is hard to understand you might try “refunctionalizing” it by replacing the cases with functions.)

We see that case classes are more powerful than variants in OCaml, without GADTs at least. Cases can have “input” type variables (appearing in arguments) which do not appear in the “output” (the type the case extends). When we match on the case these are treated as existentials. And the output type of a case can be more restrictive than the type it extends; when we match on the case we can make more restrictive assumptions about types in that branch of the match. More on this in Emir, Odersky, and Williams’ Matching Objects with Patterns.
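The second point (a case whose output type is more restrictive than the trait it extends) is easiest to see on the usual small expression language. This is only a sketch with invented names (GadtSketch, Expr, IntLit, BoolLit, If, eval), and it assumes a compiler that refines the type parameter in each branch, which I believe current Scala 2 and 3 compilers do for this simple shape.

```scala
object GadtSketch {
  sealed trait Expr[A]
  // Each literal fixes the type parameter of the trait it extends.
  case class IntLit(n: Int) extends Expr[Int]
  case class BoolLit(b: Boolean) extends Expr[Boolean]
  case class If[A](c: Expr[Boolean], t: Expr[A], e: Expr[A]) extends Expr[A]

  def eval[A](expr: Expr[A]): A = expr match {
    case IntLit(n)   => n // in this branch the compiler may assume A = Int
    case BoolLit(b)  => b // and here A = Boolean
    case If(c, t, e) => if (eval(c)) eval(t) else eval(e)
  }
}
```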

  def split[A](t: T[A]) = { 
 
    def applyT[A,R,R2] 
      (t: T[A], sk: SK[A,R], fk: FK[R], k: K[R,R2]): R2 = 
      t match { 
        case Fail() => applyFK(fk, k) 
        case Unit(a) => applySK(sk, a, fk, k) 
        case Or(t1, t2) => applyT(t1, sk, FKOr(t2, sk, fk), k) 
        case Bind(t, f) => applyT(t, SKBind(f, sk), fk, k) 
        case Apply(t, f) => applyT(t, SKApply(f, sk), fk, k) 
        case Filter(t, p) => applyT(t, SKFilter(p, sk), fk, k) 
        case Unsplit(fk2) => applyFK(fk2, KUnsplit(sk, fk, k)) 
      } 
 
    def applyFK[R,R2](fk: FK[R], k: K[R,R2]): R2 = 
      fk match { 
        case FKOr(t, sk, fk) => applyT(t(), sk, fk, k) 
        case FKSplit(r) => applyK(k, r) 
      } 
 
    def applySK[A,R,R2] 
      (sk: SK[A,R], a: A, fk: FK[R], k: K[R,R2]): R2 = 
      sk match { 
        case SKBind(f, sk) => applyT(f(a), sk, fk, k) 
        case SKApply(f, sk) => applySK(sk, f(a), fk, k) 
        case SKFilter(p, sk) => 
          if (p(a)) applySK(sk, a, fk, k) else applyFK(fk, k) 
        case SKSplit(rf) => applyK(k, rf(a, fk)) 
      } 

Again, each of these cases corresponds directly to a function in the original code, and again it is easy to match them up (modulo the extra return continuation argument) to see that all we have done is separated the data part of the function (i.e. the captured variables) from the code part.

The exception is Unsplit, which again corresponds to bind(unit(fk), unsplit). To apply it, we apply fk (which collapses unit(fk), bind, and the application of fk in unsplit) with KUnsplit as continuation, capturing sk, fk, and k (corresponding to their capture in the success continuation of bind).

    def applyK[R,R2](k: K[R,R2], r: R): R2 = 
      k match { 
        case KReturn() => r.asInstanceOf[R2] 
        case KUnsplit(sk, fk, k) => { 
          r match { 
            case None => applyFK(fk, k) 
            case Some((a, t)) => applyT(or(unit(a), t), sk, fk, k) 
          } 
        } 
      } 

For KReturn we just return the result. Although KReturn extends K[R,R], Scala doesn’t deduce from this that R = R2, so we must coerce the result. For KUnsplit we do the same match as unsplit, then apply the resulting T (for the None case we call the failure continuation directly instead of applying fail). Here Scala deduces from the return type of KUnsplit that it is safe to treat r as an Option.

    applyT[A,O[A],O[A]]( 
      t, 
      SKSplit((a, fk) => Some((a, Unsplit(fk)))), 
      FKSplit(None), 
      KReturn()) 
  } 
} 

Finally we apply the input T in correspondence to the original split.

Tail call elimination

(This section has been revised; you can see the original here.)

To eliminate the stack frames from tail calls, we next rewrite the four mutually-recursive functions into a single recursive function (which Scala compiles as a loop). To do this we have to abandon some type safety (but only in the implementation of the Logic monad; we’ll still present the same safe interface).

object LogicSFKDefuncTailrec extends Logic { 
  type O[A] = Option[(A,T[A])] 
 
  type T[A] = I 
 
  sealed trait I 
  case class Fail() extends I 
  case class Unit(a: Any) extends I 
  case class Or(t1: I, t2: () => I) extends I 
  case class Bind(t: I, f: Any => I) extends I 
  case class Apply(t: I, f: Any => Any) extends I 
  case class Filter(t: I, p: Any => Boolean) extends I 
  case class Unsplit(fk: I) extends I 
 
  case class FKOr(t: () => I, sk: I, fk: I) extends I 
  case class FKSplit(r: O[Any]) extends I 
 
  case class SKBind(f: Any => I, sk: I) extends I 
  case class SKApply(f: Any => Any, sk: I) extends I 
  case class SKFilter(p: Any => Boolean, sk: I) extends I 
  case class SKSplit(r: (Any, I) => O[Any]) extends I 
 
  case object KReturn extends I 
  case class KUnsplit(sk: I, fk: I, k: I) extends I 

This is all pretty much as before except that we erase all the type parameters. Having done so we can combine the four defunctionalized types into a single type I (for “instruction” perhaps), which will allow us to write a single recursive apply function. The type parameter in T[A] is then a phantom type since it does not appear on the right-hand side of the definition; it is used only to enforce constraints outside the module.

  def fail[A]: T[A] = Fail() 
  def unit[A](a: A): T[A] = Unit(a) 
  def or[A](t1: T[A], t2: => T[A]): T[A] = Or(t1, { () => t2 }) 
  def bind[A,B](t: T[A], f: A => T[B]): T[B] = 
    Bind(t, f.asInstanceOf[Any => I]) 
  def apply[A,B](t: T[A], f: A => B): T[B] = 
    Apply(t, f.asInstanceOf[Any => Any]) 
  def filter[A](t: T[A], p: A => Boolean): T[A] = 
    Filter(t, p.asInstanceOf[Any => Boolean]) 

The functions for building T[A] values are mostly the same. We have to cast passed-in functions since Any is not a subtype of arbitrary A. The return type annotations don’t seem necessary but I saw some strange type errors without them (possibly related to the phantom type?) when using the Logic.Syntax wrapper.

def split[A](t: T[A]): O[A] = { 
  def apply(i: I, a: Any, r: O[Any], sk: I, fk: I, k: I): O[Any] = 
    i match { 
      case Fail() => apply(fk, null, null, null, null, k) 
      case Unit(a) => apply(sk, a, null, null, fk, k) 
      case Or(t1, t2) => 
        apply(t1, null, null, sk, FKOr(t2, sk, fk), k) 
      case Bind(t, f) => 
        apply(t, null, null, SKBind(f, sk), fk, k) 
      case Apply(t, f) => 
        apply(t, null, null, SKApply(f, sk), fk, k) 
      case Filter(t, p) => 
        apply(t, null, null, SKFilter(p, sk), fk, k) 
      case Unsplit(fk2) => 
        apply(fk2, null, null, null, null, KUnsplit(sk, fk, k)) 
 
      case FKOr(t, sk, fk) => apply(t(), null, null, sk, fk, k) 
      case FKSplit(r) => apply(k, null, r, null, null, null) 
 
      case SKBind(f, sk) => apply(f(a), null, null, sk, fk, k) 
      case SKApply(f, sk) => apply(sk, f(a), null, null, fk, k) 
      case SKFilter(p, sk) => 
        if (p(a)) 
          apply(sk, a, null, null, fk, k) 
        else 
          apply(fk, null, null, null, null, k) 
      case SKSplit(rf) => 
        apply(k, null, rf(a, fk), null, null, null) 
 
      case KReturn => r 
      case KUnsplit(sk, fk, k) => { 
        r match { 
          case None => apply(fk, null, null, null, null, k) 
          case Some((a, t)) => 
            apply(or(unit(a), t), null, null, sk, fk, k) 
        } 
      } 
    } 
 
  apply(t, 
        null, 
        null, 
        SKSplit((a, fk) => Some((a, Unsplit(fk)))), 
        FKSplit(None), 
        KReturn).asInstanceOf[O[A]] 
} 

The original functions took varying arguments; the single function takes all the arguments which the original ones did. We pass null for unused arguments in each call, but otherwise the cases are the same as before.
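The same merging trick is easier to see on a much smaller pair of mutually recursive functions. In this sketch (names TailrecSketch, Which, Even, Odd, and parity are all made up), isEven and isOdd become cases of one tag type, and the single function dispatches on the tag; the @tailrec annotation makes the compiler reject the definition if it cannot compile it as a loop.

```scala
import scala.annotation.tailrec

object TailrecSketch {
  // One constructor per original function.
  sealed trait Which
  case object Even extends Which
  case object Odd extends Which

  // parity(Even, n) answers "is n even?"; parity(Odd, n) answers "is n odd?".
  // Every recursive call is a self-call in tail position.
  @tailrec
  def parity(which: Which, n: Int): Boolean = which match {
    case Even => if (n == 0) true else parity(Odd, n - 1)
    case Odd  => if (n == 0) false else parity(Even, n - 1)
  }
}
```

Written as two separate methods calling each other, the same computation would use one stack frame per step.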

Now we can evaluate nat to large N without running out of stack (but since the running time is quadratic it takes longer than I care to wait to complete):

scala> run(nat, 100000) 
^C 

See the complete code here.

Next time we’ll thread state through this backtracking logic monad, and use it to implement unification.

Wednesday, April 6, 2011

Logic programming in Scala, part 1

I got a new job where I am hacking some Scala. I thought I would learn something by translating some functional code into Scala, and a friend had recently pointed me to Kiselyov et al.’s Backtracking, Interleaving, and Terminating Monad Transformers, which provides a foundation for Prolog-style logic programming. Of course, a good translation should use the local idiom. So in this post (and the next) I want to explore an embedded domain-specific language for logic programming in Scala.

A search problem

Here is a problem I sometimes give in interviews:

Four people need to cross a rickety bridge, which can hold only two people at a time. It’s a moonless night, so they need a light to cross; they have one flashlight with a battery which lasts 60 minutes. Each person crosses the bridge at a different speed: Alice takes 5 minutes, Bob takes 10, Candace takes 20 minutes, and Dave 25. How do they get across?

I’m not interested in the answer—I’m interviewing programmers, not law school applicants—but rather in how to write a program to find the answer.

The basic shape of the solution is to represent the state of the world (where are the people, where is the flashlight, how much battery is left), write a function to compute from any particular state the set of possible next states, then search for an answer (a path from the start state to the final state) in the tree formed by applying the next state function transitively to the start state. (Here is a paper describing solutions in Prolog and Haskell.)

Here is a first solution in Scala:

object Bridge0 { 
  object Person extends Enumeration { 
    type Person = Value 
    val Alice, Bob, Candace, Dave = Value 
    val all = List(Alice, Bob, Candace, Dave) // values is broken 
  } 
  import Person._ 
 
  val times = Map(Alice -> 5, Bob -> 10, Candace -> 20, Dave -> 25) 
 
  case class State(left: List[Person], 
                   lightOnLeft: Boolean, 
                   timeRemaining: Int) 

We define an enumeration of people (the Enumeration class is a bit broken in Scala 2.8.1), a map of the time each takes to cross, and a case class to store the state of the world: the list of people on the left side of the bridge (the right side is just the complement); whether the flashlight is on the left; and how much time remains in the flashlight.

  def chooseTwo(list: List[Person]): List[(Person,Person)] = { 
    val init: List[(Person, Person)] = Nil 
    list.foldLeft(init) { (pairs, p1) => 
      list.foldLeft(pairs) { (pairs, p2) => 
        if (p1 < p2) (p1, p2) :: pairs else pairs 
      } 
    } 
  } 

This function returns the list of pairs of people from the input list. We use foldLeft to do a double loop over the input list, accumulating pairs (p1, p2) where p1 < p2; this avoids returning both (Alice, Bob) and (Bob, Alice). The use of foldLeft is rather OCamlish, and if you know Scala you will complain that foldLeft is not idiomatic; we will repair this shortly.

In Scala, Nil doesn’t have type 'a list like in OCaml and Haskell, but rather List[Nothing]. The way local type inference works, the type variable in the type of foldLeft is instantiated with the type of the init argument, so you have to ascribe a type to init (or explicitly instantiate the type variable with foldLeft[List[(Person, Person)]]) or else you get a type clash between List[Nothing] and List[(Person, Person)].
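A small sketch of the workarounds, on a list of Ints rather than people. The object name FoldSketch is invented, and the third variant (List.empty) is not mentioned above but is a standard-library way to get a typed empty list.

```scala
object FoldSketch {
  val xs = List(1, 2, 3)

  // 1. Ascribe a type to the initial value.
  val init: List[Int] = Nil
  val a = xs.foldLeft(init)((acc, x) => x :: acc)

  // 2. Instantiate foldLeft's type variable explicitly.
  val b = xs.foldLeft[List[Int]](Nil)((acc, x) => x :: acc)

  // 3. Build an empty list that already has the right element type.
  val c = xs.foldLeft(List.empty[Int])((acc, x) => x :: acc)
}
```

All three produce the reversed list, since each step conses onto the front of the accumulator.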

  def next(state: State): List[State] = { 
    if (state.lightOnLeft) { 
      val init: List[State] = Nil 
      chooseTwo(state.left).foldLeft(init) { 
        case (states, (p1, p2)) => 
          val timeRemaining = 
            state.timeRemaining - math.max(times(p1), times(p2)) 
          if (timeRemaining >= 0) { 
            val left = 
              state.left.filterNot { p => p == p1 || p == p2 } 
            State(left, false, timeRemaining) :: states 
          } 
          else 
            states 
      } 
    } else { 
      val right = Person.all.filterNot(state.left.contains) 
      val init: List[State] = Nil 
      right.foldLeft(init) { (states, p) => 
        val timeRemaining = state.timeRemaining - times(p) 
        if (timeRemaining >= 0) 
          State(p :: state.left, true, timeRemaining) :: states 
        else 
          states 
      } 
    } 
  } 

Here we compute the set of successor states for a state. We make a heuristic simplification: when the flashlight is on the left (the side where everyone begins) we move two people from the left to the right; when it is on the right we move only one. I don’t have a proof that an answer must take this form, but I believe it, and it makes the code shorter.

So when the light is on the left we fold over all the pairs of people still on the left, compute the time remaining if they were to cross, and if it is not negative build a new state where they and the flashlight are moved to the right and the time remaining updated.

If the light is on the right we do the same in reverse, but choose only one person to move.

  def tree(path: List[State]): List[List[State]] = 
    next(path.head). 
      map(s => tree(s :: path)). 
        foldLeft(List(path)) { _ ++ _ } 
 
  def search: List[List[State]] = { 
    val start = List(State(Person.all, true, 60)) 
    tree(start).filter { _.head.left == Nil } 
  } 
} 

A list of successive states is a path (with the starting state at the end and the most recent state at the beginning); the state tree is a set of paths. The tree rooted at a path is the set of paths with the input path as a suffix. To compute this tree, we find the successor states of the head of the path, augment the path with each state in turn, recursively find the tree rooted at each augmented path, then append them all (including the input path).

Then to find an answer, we generate the state tree rooted at the path consisting only of the start state (everybody and the flashlight on the left, 60 minutes remaining on the light), then keep only the paths which end in a final state (everybody on the right).

For-comprehensions

To make the code above more idiomatic Scala (and more readable), we would of course use for-comprehensions, for example:

  def chooseTwo(list: List[Person]): List[(Person,Person)] = 
    for { p1 <- list; p2 <- list; if p1 < p2 } yield (p1, p2) 

Just as before, we do a double loop over the input list, returning pairs where p1 < p2. (However, under the hood the result list is constructed by appending to a ListBuffer rather than with ::, so the pairs are returned in the reverse order.)

The for-comprehension syntax isn’t specific to lists. It’s syntactic sugar which translates to method calls, so we can use it on any objects which implement the right methods. The methods we need are

  def filter(p: A => Boolean): T[A] 
  def map[B](f: A => B): T[B] 
  def flatMap[B](f: A => T[B]): T[B] 
  def withFilter(p: A => Boolean): T[A] 

where T is some type constructor, like List. For List, filter and map have their ordinary meaning, and flatMap is a map (where the result type must be a list) which concatenates the resulting lists (that is, it flattens the list of lists).

WithFilter is like filter but should be implemented as a “virtual” filter for efficiency—for List it doesn’t build a new filtered list, but instead just keeps track of the filter function; this way multiple adjacent filters can be combined and the result produced with a single pass over the list.
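Since the sugar only needs these methods, the same notation works on other standard types. Here is a sketch with Option, which provides map, flatMap, and withFilter; the names ForSketch and addPositive are invented for the example.

```scala
object ForSketch {
  // Succeeds only if both inputs are present and their sum is positive.
  def addPositive(x: Option[Int], y: Option[Int]): Option[Int] =
    for { a <- x; b <- y; if a + b > 0 } yield a + b
}
```

If either generator is None, or the guard rejects the pair, the whole comprehension is None.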

The details of the translation are in the Scala reference manual, section 6.19. Roughly speaking, <- becomes flatMap, if becomes filter, and yield becomes map. So another way to write chooseTwo is:

  def chooseTwo(list: List[Person]): List[(Person,Person)] = 
    list.flatMap(p1 => 
      list.filter(p2 => p1 < p2).map(p2 => (p1, p2))) 

The logic monad

So far we have taken a concrete view of the choices that arise in searching the state tree, by representing a choice among alternatives as a list. For example, in the chooseTwo function we returned a list of alternative pairs. I want now to take a more abstract view, and define an abstract type T[A] to represent a choice among alternatives of type A, along with operations on the type, packaged into a trait:

trait Logic { L => 
  type T[A] 
 
  def fail[A]: T[A] 
  def unit[A](a: A): T[A] 
  def or[A](t1: T[A], t2: => T[A]): T[A] 
  def apply[A,B](t: T[A], f: A => B): T[B] 
  def bind[A,B](t: T[A], f: A => T[B]): T[B] 
  def filter[A](t: T[A], p: A => Boolean): T[A] 
  def split[A](t: T[A]): Option[(A,T[A])] 

A fail value is a choice among no alternatives. A unit(a) is a choice of a single alternative. The value or(t1, t2) is a choice among the alternatives represented by t1 together with the alternatives represented by t2.

The meaning of applying a function to a choice of alternatives is a choice among the results of applying the function to each alternative; that is, if t represents a choice among 1, 2, and 3, then apply(t, f) represents a choice among f(1), f(2), and f(3).

Bind is the same except the function returns a choice of alternatives, so we must combine all the alternatives in the result; that is, if t is a choice among 1, 3, and 5, and f is { x => or(unit(x), unit(x + 1)) }, then bind(t, f) is a choice among 1, 2, 3, 4, 5, and 6.

A filter of a choice of alternatives by a predicate is a choice among only the alternatives which pass the predicate.

Finally, split is a function which returns the first alternative in a choice of alternatives (if there is at least one) along with a choice among the remaining alternatives.
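To make these meanings concrete, here is a sketch that reads a choice as a plain List, which is only one possible model of T[A]. The object name ListModel and the value names are invented.

```scala
object ListModel {
  // A choice among 1, 3, and 5.
  val t = List(1, 3, 5)

  // apply: a choice among f(1), f(3), f(5).
  val applied = t.map(_ * 10)

  // bind: f returns a choice, and all the resulting alternatives are combined.
  val bound = t.flatMap(x => List(x, x + 1))

  // filter: only the alternatives which pass the predicate.
  val evens = bound.filter(_ % 2 == 0)

  // split: the first alternative plus a choice among the rest.
  val firstAndRest: Option[(Int, List[Int])] = bound match {
    case h :: rest => Some((h, rest))
    case Nil       => None
  }
}
```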

  def or[A](as: List[A]): T[A] = 
    as.foldRight(fail[A])((a, t) => or(unit(a), t)) 
 
  def run[A](t: T[A], n: Int): List[A] = 
    if (n <= 0) Nil else 
      split(t) match { 
        case None => Nil 
        case Some((a, t)) => a :: run(t, n - 1) 
      } 

As a convenience, or(as: List[A]) means a choice among the elements of as. And run returns a list of the first n alternatives in a choice, picking them off one by one with split; this is how we get answers out of a T[A].

  case class Syntax[A](t: T[A]) { 
    def map[B](f: A => B): T[B] = L.apply(t, f) 
    def filter(p: A => Boolean): T[A] = L.filter(t, p) 
    def flatMap[B](f: A => T[B]): T[B] = L.bind(t, f) 
    def withFilter(p: A => Boolean): T[A] = L.filter(t, p) 
 
    def |(t2: => T[A]): T[A] = L.or(t, t2) 
  } 
 
  implicit def syntax[A](t: T[A]) = Syntax(t) 
} 

Here we hook into the for-comprehension notation, by wrapping values of type T[A] in an object with the methods we need (and | as an additional bit of syntactic sugar), which methods just delegate to the functions defined above. We arrange with an implicit conversion for these wrappers to spring into existence when we need them.

The bridge puzzle with the logic monad

Now we can rewrite the solution in terms of the Logic trait:

class Bridge(Logic: Logic) { 
  import Logic._ 

We pass an implementation of the logic monad in, then open it so the implicit conversion is available (we can also use T[A] and the Logic functions without qualification).

The Person, times, and State definitions are as before.

  private def chooseTwo(list: List[Person]): T[(Person,Person)] = 
    for { p1 <- or(list); p2 <- or(list); if p1 < p2 } 
    yield (p1, p2) 

As we saw, we can write chooseTwo more straightforwardly using a for-comprehension. In the previous version we punned on list as a concrete list and as a choice among alternatives; here we convert one to the other explicitly.

  private def next(state: State): T[State] = { 
    if (state.lightOnLeft) { 
      for { 
        (p1, p2) <- chooseTwo(state.left) 
        timeRemaining = 
          state.timeRemaining - math.max(times(p1), times(p2)) 
        if timeRemaining >= 0 
      } yield { 
        val left = 
          state.left.filterNot { p => p == p1 || p == p2 } 
        State(left, false, timeRemaining) 
      } 
    } else { // ... 

This is pretty much as before, except with for-comprehensions instead of foldLeft and explicit consing. (You can easily figure out the branch for the flashlight on the right.)

  private def tree(path: List[State]): T[List[State]] = 
    unit(path) | 
      (for { 
         state <- next(path.head) 
         path <- tree(state :: path) 
       } yield path) 
 
  def search(n: Int): List[List[State]] = { 
    val start = List(State(Person.all, true, 60)) 
    val t = 
      for { path <- tree(start); if path.head.left == Nil } 
      yield path 
    run(t, n) 
  } 
} 

In tree we use | to adjoin the input path (previously we gave it in the initial value of foldLeft). In search we need to actually run the Logic.T[A] value rather than returning it, because it’s an abstract type and can’t escape the module (see the Postscript for an alternative); this is why the other methods must be private.

Implementing the logic monad with lists

We can recover the original solution by implementing Logic with lists:

object LogicList extends Logic { 
  type T[A] = List[A] 
 
  def fail[A] = Nil 
  def unit[A](a: A) = a :: Nil 
  def or[A](t1: List[A], t2: => List[A]) = t1 ::: t2 
  def apply[A,B](t: List[A], f: A => B) = t.map(f) 
  def bind[A,B](t: List[A], f: A => List[B]) = t.flatMap(f) 
  def filter[A](t: List[A], p: A => Boolean) = t.filter(p) 
  def split[A](t: List[A]) = 
    t match { 
      case Nil => None 
      case h :: t => Some((h, t))
    } 
} 

A choice among alternatives is just a List of the alternatives, so the semantics we sketched above are realized in a very direct way.

The downside to the List implementation is that we compute all the alternatives, even if we only care about one of them. (In the bridge problem any path to the final state is a satisfactory answer, but our program computes all such paths, even if we pass an argument to search requesting only one answer.) We might even want to solve problems with an infinite number of solutions.
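The eagerness is easy to observe with a counter. This is a minimal sketch, not the bridge code: EagerSketch, its evaluated counter, and the simplified run (written with take, which for lists behaves like the split-based run) are all hypothetical.

```scala
object EagerSketch {
  // Counts how many alternatives were actually computed.
  var evaluated = 0

  // Simplified list version of run: the first n alternatives.
  def run[A](t: List[A], n: Int): List[A] = t.take(n)

  // Ask for a single answer, but note how much work map does first.
  def firstSquare(): List[Int] = {
    evaluated = 0
    val squares = List(1, 2, 3, 4, 5).map { x => evaluated += 1; x * x }
    run(squares, 1)
  }
}
```

Calling EagerSketch.firstSquare() returns List(1), yet evaluated ends up at 5: every alternative was computed before run got to discard four of them.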

Next time we’ll repair this downside by implementing the backtracking monad from the paper by Kiselyov et al.

See the complete code here.

Postscript: modules in Scala

I got the idea of implementing the for-comprehension methods as an implicit wrapper from Edward Kmett’s functorial library. It’s nice that T[A] remains completely abstract, and the for-comprehension notation is just sugar. I also tried an implementation where T[A] is bounded by a trait containing the methods:

trait Monadic[T[_], A] { 
  def map[B](f: A => B): T[B] 
  def filter(p: A => Boolean): T[A] 
  def flatMap[B](f: A => T[B]): T[B] 
  def withFilter(p: A => Boolean): T[A] 
 
  def |(t: => T[A]): T[A] 
  def split: Option[(A,T[A])] 
} 
 
trait Logic { 
  type T[A] <: Monadic[T, A] 
  // no Syntax class needed 

This works too but the type system hackery is a bit ugly, and it constrains implementations of Logic more than is necessary.

Another design choice is whether T[A] is an abstract type (as I have it) or a type parameter of Logic:

trait Logic[T[_]] { L => 
  // no abstract type T[A] but otherwise as before 
} 

Neither alternative provides the expressivity of OCaml modules (but see addendum below): with abstract types, consumers of Logic cannot return values of T[A] (as we saw above); with a type parameter, they can, but the type is no longer abstract.
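The type-parameter side of that trade-off can be sketched as follows. Everything here is a cut-down, hypothetical version (a two-method Logic, a Client class standing in for Bridge, and run written with take), assuming a Scala version where higher-kinded types need no extra language import.

```scala
object ParamSketch {
  // Cut-down Logic with the representation as a type parameter.
  trait Logic[T[_]] {
    def or[A](as: List[A]): T[A]
    def run[A](t: T[A], n: Int): List[A]
  }

  object LogicList extends Logic[List] {
    def or[A](as: List[A]): List[A] = as
    def run[A](t: List[A], n: Int): List[A] = t.take(n)
  }

  // Hypothetical consumer standing in for Bridge; it can return T[Int].
  class Client[T[_]](val logic: Logic[T]) {
    def search: T[Int] = logic.or(List(1, 2, 3))
  }

  val b = new Client[List](LogicList)

  // The result is visibly a List, so callers can use List methods directly.
  val leaked: List[Int] = b.search.reverse
}
```

The consumer can return its T[Int], but outside the class nothing stops us from treating the result as an ordinary List, which is exactly the loss of abstraction described above.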

In OCaml we would write

module type Logic = 
sig 
  type 'a t 
 
  val unit : 'a -> 'a t 
  (* and so on *) 
end 
 
module Bridge(L : Logic) = 
struct 
  type state = ... 
  let search : state list L.t = ...
end 

and get both the abstract type and the ability to return values of the type.

Addendum

Jorge Ortiz points out in the comments that it is possible to keep T[A] abstract and also return its values from Bridge, by making the Logic argument a (public) val. We can then remove the privates, and write search as just:

  def search: T[List[State]] = { 
    val start = List(State(Person.all, true, 60)) 
    for { path <- tree(start); if path.head.left == Nil } 
    yield path 
  } 

instead of baking run into it. Now, if we write val b = new Bridge(LogicList) then b.search has type b.Logic.T[List[b.State]], and we can call b.Logic.run to evaluate it.
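A stripped-down sketch of this arrangement, with a two-method Logic and a hypothetical Client class in place of Bridge (neither is the real code), shows the path-dependent types lining up:

```scala
object ValLogicSketch {
  // Cut-down Logic with an abstract type member.
  trait Logic {
    type T[A]
    def or[A](as: List[A]): T[A]
    def run[A](t: T[A], n: Int): List[A]
  }

  object LogicList extends Logic {
    type T[A] = List[A]
    def or[A](as: List[A]): List[A] = as
    def run[A](t: List[A], n: Int): List[A] = t.take(n)
  }

  // Hypothetical stand-in for Bridge: the Logic argument is a public val,
  // so the abstract type can appear in a public signature.
  class Client(val Logic: Logic) {
    def search: Logic.T[Int] = Logic.or(List(1, 2, 3))
  }

  val b = new Client(LogicList)

  // b.search has type b.Logic.T[Int]; only b.Logic.run accepts it.
  val answers: List[Int] = b.Logic.run(b.search, 2)
}
```

Here answers is List(1, 2). The value b.search is opaque to the caller; the only thing to do with it is hand it back to b.Logic.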

This is only a modest improvement; what’s still missing, compared to the OCaml version, is the fact that LogicList and b.Logic are the same module. So we can’t call LogicList.run(b.search) directly. Worse, we can’t compose modules which use the same Logic implementation, because they each have their own incompatibly-typed Logic member.

I thought there might be a way out of this using singleton types—the idea is that a match of a value v against a typed pattern where the type is w.type succeeds when v eq w (section 8.2 in the reference manual). So we can define

def run[A]( 
  Logic: Logic, 
  b: Bridge, 
  t: b.Logic.T[A], 
  n: Int): List[A] = 
{ 
  Logic match { 
    case l : b.Logic.type => l.run(t, n) 
  } 
} 

which is accepted by the compiler, but sadly the call is not:

scala> run[List[b.State]](LogicList, b, b.search, 2) 
<console>:8: error: type mismatch; 
 found   : b.Logic.T[List[b.State]] 
 required: b.Logic.T[List[b.State]] 
       run[List[b.State]](LogicList, b, b.search, 2) 
                                          ^ 

Addendum addendum

Some further advice from Jorge Ortiz: the specific type of Logic (not just Logic.type) can be exposed outside Bridge either through polymorphism:

class Bridge[L <: Logic](val Logic: L) { 
  ... 
} 
 
val b = new Bridge(LogicList) 

or by defining an abstract value (this works the same if Bridge is a trait):

abstract class Bridge { 
  val Logic: Logic 
  ... 
} 

So we can compose uses of T but it remains abstract.
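The polymorphic variant can be sketched in the same cut-down setting (again a hypothetical two-method Logic and a Client class standing in for Bridge). The point is that the type of b.Logic is now known to be that of LogicList, so LogicList.run accepts b.search directly.

```scala
object PolySketch {
  // Cut-down Logic with an abstract type member.
  trait Logic {
    type T[A]
    def or[A](as: List[A]): T[A]
    def run[A](t: T[A], n: Int): List[A]
  }

  object LogicList extends Logic {
    type T[A] = List[A]
    def or[A](as: List[A]): List[A] = as
    def run[A](t: List[A], n: Int): List[A] = t.take(n)
  }

  // Hypothetical stand-in for Bridge: the specific type L of the Logic
  // module is exposed through the type parameter.
  class Client[L <: Logic](val Logic: L) {
    // Inside the class, T is still abstract: only Logic's operations apply.
    def search: Logic.T[Int] = Logic.or(List(1, 2, 3))
  }

  val b = new Client(LogicList)

  // Accepted, because b.Logic is statically known to be LogicList.
  val answers: List[Int] = LogicList.run(b.search, 2)
}
```

Here answers is List(1, 2), and two such clients built from LogicList could exchange T values with each other in the same way.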