Wednesday, June 8, 2011

Logic programming in Scala, part 3: unification and state

In this post I want to build on the backtracking logic monad we covered last time by adding unification, yielding an embedded DSL for Prolog-style logic programming.

Prolog

Here is a small Prolog example, the rough equivalent of List.contains in Scala:

  member(X, [X|T]). 
  member(X, [_|T]) :- member(X, T). 

Member doesn’t return a boolean; instead it succeeds or fails (in the same way as the logic monad). The goal member(1, [1,2,3]) succeeds; the goal member(4, [1,2,3]) fails. (What happens for member(1, [1,1,3])?)

A Prolog predicate is defined by one or more clauses (each ending in a period), made up of a head (the predicate and arguments before the :-) and zero or more subgoals (goals after the :-, separated by commas; if there are no subgoals the :- is omitted). To solve a goal, we unify it (match it) with each clause head, then solve each subgoal in the clause. If a subgoal fails we backtrack and try the next matching head; if there is no matching head the goal fails. A goal may succeed more than once.

For member we have two clauses: the first says that member succeeds if X is the head of the list ([X|T] is the same as x::t in Scala); the second says that member succeeds if X is a member of the tail of the list, regardless of the head. There is no clause where the list is empty (written []); a goal with an empty list fails since there is no matching clause head.

Prolog unification is more expressive than pattern matching as found in Scala, OCaml, etc. Both sides of a unification may contain variables; unification attempts to instantiate them so that the two sides are equal. Variables are instantiated by terms, which themselves may contain variables; unification finds the most general instantiation which makes the sides equal.

As a small example of this expressivity, we can run member “backwards”: the goal member(X, [1,2,3]) succeeds once for each element of the list, with X bound to the element.
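To make the contrast with pattern matching concrete, here is a minimal untyped sketch of first-order unification in Scala. It illustrates the idea under simplifying assumptions and is not the typed representation developed below. There is a single term type Tm with hypothetical constructors V (variables) and Fn (function symbols; constants are zero-argument functions). The substitution is a Map that walk follows on demand, rather than one kept fully substituted.

```scala
// Untyped terms: a variable, or a function symbol applied to arguments.
// Constants such as a and b are function symbols with no arguments.
sealed trait Tm
final case class V(name: String) extends Tm
final case class Fn(name: String, args: List[Tm]) extends Tm

object Unify {
  type Subst = Map[String, Tm]

  // Follow bindings until we reach an unbound variable or a function term.
  def walk(t: Tm, s: Subst): Tm = t match {
    case V(n) if s.contains(n) => walk(s(n), s)
    case _ => t
  }

  // Occurs check: does variable n appear in t (under the substitution s)?
  def occurs(n: String, t: Tm, s: Subst): Boolean = walk(t, s) match {
    case V(m) => m == n
    case Fn(_, args) => args.exists(a => occurs(n, a, s))
  }

  // Either side may contain variables. Returns the extended substitution,
  // or None if no instantiation makes the two sides equal.
  def unify(a: Tm, b: Tm, s: Subst): Option[Subst] = (walk(a, s), walk(b, s)) match {
    case (V(x), V(y)) if x == y => Some(s)
    case (V(x), t) => if (occurs(x, t, s)) None else Some(s.updated(x, t))
    case (t, V(y)) => if (occurs(y, t, s)) None else Some(s.updated(y, t))
    case (Fn(f, xs), Fn(g, ys)) if f == g && xs.length == ys.length =>
      xs.zip(ys).foldLeft(Option(s)) { case (acc, (l, r)) => acc.flatMap(s1 => unify(l, r, s1)) }
    case _ => None
  }

  // Apply the substitution all the way down, to read off an answer.
  def resolve(t: Tm, s: Subst): Tm = walk(t, s) match {
    case Fn(f, args) => Fn(f, args.map(a => resolve(a, s)))
    case v => v
  }
}
```

Unifying f(X, b) with f(a, Y) binds X to a and Y to b, with variables on both sides. Unifying g(X, X) with g(Y, a) binds both X and Y to a. Unifying X with f(X) is rejected by the occurs check.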

There is much more on Prolog and logic programming in Frank Pfenning’s course notes, which I recommend highly.

Unification

For each type we want to use in unification we’ll define a corresponding type of terms, which have the same structure as the underlying type but can also contain variables. These aren’t Scala variables (which of course can’t be stored in a data structure) but “existential variables”, or evars. Evars are just tags; computations will carry an environment mapping evars to terms, which may be updated after a successful unification.

import scala.collection.immutable.{Map,HashMap} 
 
class Evar[A](val name: String) 
object Evar { def apply[A](name: String) = new Evar[A](name) } 
 
trait Term[A] { 
  // invariant: on call to unify, this and t have e substituted 
  def unify(e: Env, t: Term[A]): Option[Env] 
 
  def occurs[B](v: Evar[B]): Boolean 
  def subst(e: Env): Term[A] 
  def ground: A 
} 

The important property of an evar is that it is distinct from every other evar; the name attached to it is just a label. An evar is indexed by a phantom type indicating the underlying type of terms which may be bound to it.

A term is indexed by its underlying type. So Int becomes Term[Int], String becomes Term[String], and so on; an evar of type Evar[A] may only be bound to a term of type Term[A]. (Prolog is dynamically typed, but this statically-typed treatment of evars and terms fits better with Scala.)

The unify method unifies a term with another term of the same type, taking an environment and returning an updated environment (or None if the unification fails). Occurs checks if an evar occurs in a term (as we will see this is used to prevent circular bindings). Subst substitutes the variables in a term with their bindings in an environment, and ground returns the underlying Scala value represented by the term (provided the term contains no evars).

class Env(m: Map[Evar[Any],Term[Any]]) { 
  def apply[A](v: Evar[A]) = 
    m(v.asInstanceOf[Evar[Any]]).asInstanceOf[Term[A]] 
  def get[A](v: Evar[A]): Option[Term[A]] = 
    m.get(v.asInstanceOf[Evar[Any]]).asInstanceOf[Option[Term[A]]] 
  def updated[A](v: Evar[A], t: Term[A]): Env = { 
    val v2 = v.asInstanceOf[Evar[Any]] 
    val t2 = t.asInstanceOf[Term[Any]] 
    val e2 = Env(Map(v2 -> t2)) 
    // push the new binding into every existing binding (a strict map)
    val m2 = m.map { case (k, u) => (k, u.subst(e2)) } 
    Env(m2.updated(v2, t2)) 
  } 
} 
object Env { 
  def apply(m: Map[Evar[Any],Term[Any]]) = new Env(m) 
  def empty = new Env(HashMap()) 
} 

An environment is just a map from evars to terms. Because we need to store evars and terms of different types in the same environment, we cast them to and from Any; this is safe because of the phantom type on Evar. For simplicity we maintain the invariant that the term bound to each evar is already substituted by the rest of the environment.
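The invariant is easier to see in a small untyped sketch. The names are hypothetical, and variables are plain strings. The function updated mirrors Env.updated: it substitutes the new binding into every existing binding before adding it, so a single subst pass is enough to resolve any term. Like the post's unify, it assumes the new term has already been substituted by the environment.

```scala
object EnvDemo {
  // Hypothetical untyped terms, used only to illustrate the invariant.
  sealed trait Tm
  final case class V(name: String) extends Tm
  final case class Fn(name: String, args: List[Tm]) extends Tm

  type Env = Map[String, Tm]

  // One pass of substitution. This suffices because, by the invariant (and
  // the occurs check), terms stored in env contain no variables bound in env.
  def subst(t: Tm, env: Env): Tm = t match {
    case V(n) => env.getOrElse(n, t)
    case Fn(f, args) => Fn(f, args.map(a => subst(a, env)))
  }

  // Mirrors Env.updated: push the new binding into every existing binding,
  // then add it. Assumes t has already been substituted by env.
  def updated(env: Env, v: String, t: Tm): Env = {
    val one: Env = Map(v -> t)
    env.map { case (k, u) => (k, subst(u, one)) }.updated(v, t)
  }
}
```

After binding X to f(Y) and then Y to a, the stored binding for X is already f(a). Resolving X never has to chase a chain of bindings.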

case class VarTerm[A](v: Evar[A]) extends Term[A] { 
  def unify(e: Env, t: Term[A]) = 
    t match { 
      case VarTerm(v2) if (v2 == v) => Some(e) 
      case _ => 
        if (t.occurs(v)) None 
        else Some(e.updated(v, t)) 
    } 
 
  def occurs[B](v2: Evar[B]) = v2 == v 
 
  def subst(e: Env) = 
    e.get(v) match { 
      case Some(t) => t 
      case None => this 
    } 
 
  def ground = 
    throw new IllegalArgumentException("not ground") 
 
  override def toString = { v.name  } 
} 

The VarTerm class represents terms consisting of an evar. To unify a VarTerm with another VarTerm containing the same evar, we just return the environment unchanged (since there is no new information). Otherwise we check that the evar doesn’t occur in the other term (a unification like x =:= List(x) would create a circular term), then return the environment updated with the new binding.

To substitute a VarTerm we return the term bound to the evar in the environment if one exists, otherwise the unsubstituted VarTerm. A VarTerm is never ground (we assume ground is called only on terms which are already substituted by the environment).

case class LitTerm[A](a: A) extends Term[A] { 
  def unify(e: Env, t: Term[A]) = 
    t match { 
      case LitTerm(a2) => if (a == a2) Some(e) else None 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[B](v: Evar[B]) = false 
  def subst(e: Env) = this 
  def ground = a 
 
  override def toString = { a.toString } 
} 

LitTerm represents terms of literal Scala values. A LitTerm unifies with another LitTerm containing an equal value, but that adds nothing to the environment. Then come two cases that we need for every term type: to unify with a VarTerm, call unify back on the VarTerm (which records the binding); otherwise fail.

case class NilTerm[A]() extends Term[List[A]] { 
  def unify(e: Env, t: Term[List[A]]) = 
    t match { 
      case NilTerm() => Some(e) 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[B](v: Evar[B]) = false 
  def subst(e: Env) = this 
  def ground = Nil 
 
  override def toString = { Nil.toString } 
} 
 
case class ConsTerm[A](hd: Term[A], tl: Term[List[A]]) 
  extends Term[List[A]] 
{ 
  def unify(e: Env, t: Term[List[A]]) = 
    t match { 
      case ConsTerm(hd2, tl2) => 
        for { 
          e1 <- hd.unify(e, hd2) 
          e2 <- tl.subst(e1).unify(e1, tl2.subst(e1)) 
        } yield e2 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[C](v: Evar[C]) = hd.occurs(v) || tl.occurs(v) 
  def subst(e: Env) = ConsTerm(hd.subst(e), tl.subst(e)) 
  def ground = hd.ground :: tl.ground 
 
  override def toString = { hd.toString + " :: " + tl.toString } 
} 

NilTerm and ConsTerm represent the Nil and :: constructors for lists. Nil is sort of like a literal, so the methods for NilTerm are similar to those for LitTerm. For ConsTerm we unify by unifying the heads and tails, calling subst on the tails since unifying the heads may have added bindings to the environment. (Here it’s convenient to use a for-comprehension on the Option[Env] type since either unification may fail.) Similarly we implement occurs, subst, and ground by calling them on the head and tail.

object Term { 
  implicit def var2Term[A](v: Evar[A]): Term[A] = VarTerm(v) 
  //implicit def lit2term[A](a: A): Term[A] = LitTerm(a) 
  implicit def int2Term(a: Int): Term[Int] = LitTerm(a) 
  implicit def list2Term[A](l: List[Term[A]]): Term[List[A]] = 
    l match { 
      case Nil => NilTerm[A]() 
      case hd :: tl => ConsTerm(hd, list2Term(tl)) 
    } 
} 

Finally we have some implicit conversions to make it a little easier to build Term values. The lit2term conversion turned out to be a bad idea; in particular you don’t want a LitTerm[List[A]] since it doesn’t unify with a ConsTerm[A] or NilTerm[A].

State

In order to combine unification with backtracking, we need to keep track of the environment along each branch of the tree of choices. We don’t want the environments from different branches to interfere, so it’s convenient to use a purely functional environment representation; we pass the current environment down the tree as computation proceeds. However, we can hide this state passing in the monad interface:

trait LogicState { L => 
  type T[S,A] 
  // as before 
  def split[S,A](s: S, t: T[S,A]): Option[(S,A,T[S,A])] 
 
  def get[S]: T[S,S] 
  def set[S](s: S): T[S, Unit] 
 
  case class Syntax[S,A](t: T[S,A]) { 
    // as before 
    def &[B](t2: => T[S,B]): T[S,B] = L.bind(t, { _: A => t2 }) 
  } 
} 

LogicState is mostly the same as Logic, except that the type of choices has an extra parameter for the type of the state. The get and set functions get and set the current state. To split we need an initial state to get things started, and each result includes an updated state. Finally we add the syntax & to sequence two computations, ignoring the value of the first. We’ll use this to sequence goals, since we care only about the updated environment.

The simplest implementation of LogicState builds on Logic:

trait LogicStateT extends LogicState { 
  val Logic: Logic 
 
  type T[S,A] = S => Logic.T[(S, A)] 

We embed state-passing in a Logic.T as a function from an initial state to a choice of alternatives, where each alternative includes an updated state along with its value.

  def fail[S,A] = { s: S => Logic.fail } 
  def unit[S,A](a: A) = { s: S => Logic.unit((s, a)) } 
 
  def or[S,A](t1: T[S,A], t2: => T[S,A]) = 
    { s: S => Logic.or(t1(s), t2(s)) } 
 
  def bind[S,A,B](t: T[S,A], f: A => T[S,B]) = { 
    val f2: ((S,A)) => Logic.T[(S,B)] = { case (s, a) => f(a)(s) } 
    { s: S => Logic.bind(t(s), f2) } 
  } 
 
  def apply[S,A,B](t: T[S,A], f: A => B) = { 
    val f2: ((S,A)) => ((S,B)) = { case (s, a) => (s, f(a)) } 
    { s: S => Logic.apply(t(s), f2) } 
  } 
 
  def filter[S,A](t: T[S,A], p: A => Boolean) = { 
    val p2: ((S,A)) => Boolean = { case (_, a) => p(a) } 
    { s: S => Logic.filter(t(s), p2) } 
  } 

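None of these operations changes the state itself. Unit, apply, and filter pass it through unchanged, and bind threads the state produced by t into f(a). Note that or passes the same state to both alternatives, so different branches of the tree cannot interfere with one another’s state.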

  def split[S,A](s: S, t: T[S,A]) = { 
    Logic.split(t(s)) match { 
      case None => None 
      case Some(((s, a), t)) => Some((s, a, { _ => t })) 
    } 
  } 
 
  def get[S] = { s: S => Logic.unit((s,s)) } 
  def set[S](s: S) = { _: S => Logic.unit((s,())) } 
} 

In split we pass the given state to the underlying Logic.T and unpack the state and value of the first alternative. The choice of remaining alternatives t already encapsulates the current state, so when we return it we ignore the input state. In get and set we return and replace the current state, respectively.
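The same state-passing construction works over any backtracking choice type. As a sketch, assuming Scala 2.13 or later, here it is with LazyList standing in for Logic.T. The object name LazyLogicState is hypothetical, and so is the method and, which plays the role of the & syntax.

```scala
// A sketch of LogicStateT with LazyList (Scala 2.13+) in place of Logic.T:
// a computation maps an incoming state to a lazy stream of alternatives,
// each paired with its own outgoing state.
object LazyLogicState {
  type T[S, A] = S => LazyList[(S, A)]

  def fail[S, A]: T[S, A] = (_: S) => LazyList.empty
  def unit[S, A](a: A): T[S, A] = (s: S) => LazyList((s, a))

  // Both alternatives start from the same incoming state s.
  def or[S, A](t1: T[S, A], t2: => T[S, A]): T[S, A] =
    (s: S) => t1(s) #::: t2(s)

  // Thread the state produced by t into f(a).
  def bind[S, A, B](t: T[S, A], f: A => T[S, B]): T[S, B] =
    (s: S) => t(s).flatMap { case (s1, a) => f(a)(s1) }

  def get[S]: T[S, S] = (s: S) => LazyList((s, s))
  def set[S](s: S): T[S, Unit] = (_: S) => LazyList((s, ()))

  // Sequence two computations, ignoring the value of the first (the & syntax).
  def and[S, A, B](t1: T[S, A], t2: => T[S, B]): T[S, B] = bind(t1, (_: A) => t2)

  // Run from an initial state, keeping at most n results.
  def run[S, A](s: S, t: T[S, A], n: Int): List[(S, A)] = t(s).take(n).toList
}
```

Running the equivalent of (set(1) | unit(())) & get from state 0 yields the values 1 and 0. The second branch never sees the first branch’s set.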

Another approach is to pass state explicitly through LogicSFK:

object LogicStateSFK extends LogicState { 
  type FK[R] = () => R 
  type SK[S,A,R] = (S, A, FK[R]) => R 
 
  trait T[S,A] { def apply[R](s: S, sk: SK[S,A,R], fk: FK[R]): R } 

This is not really any different from LogicStateT applied to LogicSFK; we have just uncurried the state argument. We can take the same path as last time and defunctionalize this into a tail-recursive implementation (see the full code), although LogicStateT applied to LogicSFKDefuncTailrec already inherits tail-recursiveness from the underlying Logic monad.
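For comparison, here is a minimal self-contained sketch of that uncurried form. It covers only unit, or, bind, get, and set, and omits split. The object name StateSFK is hypothetical, and so is runAll, which collects every result by consing in the success continuation. Because it collects everything, runAll only terminates for finite searches.

```scala
object StateSFK {
  type FK[R] = () => R
  type SK[S, A, R] = (S, A, FK[R]) => R

  // The rank-2 type is encoded as a polymorphic method on a trait.
  trait T[S, A] { def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R }

  def unit[S, A](a: A): T[S, A] = new T[S, A] {
    def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R = sk(s, a, fk)
  }
  // Both branches receive the same incoming state s.
  def or[S, A](t1: T[S, A], t2: => T[S, A]): T[S, A] = new T[S, A] {
    def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R = t1(s, sk, () => t2(s, sk, fk))
  }
  // The state s1 produced by t is passed on to f(a).
  def bind[S, A, B](t: T[S, A], f: A => T[S, B]): T[S, B] = new T[S, B] {
    def apply[R](s: S, sk: SK[S, B, R], fk: FK[R]): R =
      t(s, (s1: S, a: A, fk1: FK[R]) => f(a)(s1, sk, fk1), fk)
  }
  def get[S]: T[S, S] = new T[S, S] {
    def apply[R](s: S, sk: SK[S, S, R], fk: FK[R]): R = sk(s, s, fk)
  }
  def set[S](s2: S): T[S, Unit] = new T[S, Unit] {
    def apply[R](s: S, sk: SK[S, Unit, R], fk: FK[R]): R = sk(s2, (), fk)
  }

  // Collect every (final state, value) pair; only for finite searches.
  def runAll[S, A](s: S, t: T[S, A]): List[(S, A)] =
    t.apply[List[(S, A)]](s, (s1: S, a: A, fk: FK[List[(S, A)]]) => (s1, a) :: fk(), () => Nil)
}
```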

Scrolog

Finally we can put the pieces together into a Prolog-like embedded DSL:

trait Scrolog { 
  val LogicState: LogicState 
  import LogicState._ 
 
  type G = T[Env,Unit] 

From our point of view, a goal is a stateful choice among alternatives, where we don’t care about the value returned, only the environment.

  class TermSyntax[A](t: Term[A]) { 
    def =:=(t2: Term[A]): G = 
      for { 
        env <- get 
        env2 <- { 
          t.subst(env).unify(env, t2.subst(env)) match { 
            case None => fail[Env,Unit] 
            case Some(e) => set(e) 
          } 
        } 
      } yield env2 
  } 
 
  implicit def termSyntax[A](t: Term[A]) = new TermSyntax(t) 
  implicit def syntax[A](t: G) = LogicState.syntax(t) 

We connect term unification to the stateful logic monad with a wrapper class defining a =:= operator. To unify terms in the monad, we get the current environment, substitute it into the two terms (to satisfy the invariant above), then call unify; if it fails we fail the computation, else we set the new state.

  def run[A](t: G, n: Int, tm: Term[A]): List[Term[A]] = 
    LogicState.run(Env.empty, t, n) 
      .map({ case (e, _) => tm.subst(e) }) 
} 

The run function solves a goal, taking as arguments the goal, the maximum number of solutions to find, and a term to be evaluated in the environment of each solution.

Examples

First we need to set up Scrolog:

val Scrolog = 
  new Scrolog { val LogicState = 
    new LogicStateT { val Logic = LogicSFKDefuncTailrec } 
  } 
import Scrolog._ 

Here is a translation of the member predicate:

  def member[A](x: Term[A], l: Term[List[A]]): G = { 
    val hd = Evar[A]("hd"); val tl = Evar[List[A]]("tl") 
    ConsTerm(x, tl) =:= l | 
    (ConsTerm(hd, tl) =:= l & member(x, tl)) 
  } 

We implement predicates by functions, and goals by function calls. To match a clause head, we explicitly unify the input arguments against it, and we combine the clauses with |. Subgoals are sequenced with &. Finally, we must create local evars explicitly, since they are fresh for each call (just as local variables are in Scala).

Now we can run the goal member(X, [1,2,3]) from above, with the evar x in the role of X:

scala> val x = Evar[Int]("x") 
scala> run(member(x, List[Term[Int]](1, 2, 3)), 3, x) 
res6: List[Term[Int]] = List(1, 2, 3) 
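To experiment without the full code, here is a self-contained miniature of the same idea. It is a sketch under simplifying assumptions, not Scrolog itself:

- Terms are untyped, with hypothetical constructors Tm, V, and Fn.
- Goals are functions from an environment to a LazyList of environments, rather than T[Env,Unit].
- Evars are kept distinct by a counter rather than by object identity.
- =:=, | and & are spelled eqTerm, or, and and.

```scala
// A hypothetical, untyped miniature of Scrolog (not the post's code).
object MiniLog {
  sealed trait Tm
  // Evars are distinguished by id; fresh() below hands out new ids.
  final case class V(name: String, id: Int) extends Tm
  final case class Fn(name: String, args: List[Tm]) extends Tm

  type Env = Map[V, Tm]
  // A goal maps an environment to a lazy stream of successful environments.
  type G = Env => LazyList[Env]

  def walk(t: Tm, e: Env): Tm = t match {
    case v: V if e.contains(v) => walk(e(v), e)
    case _ => t
  }
  def occurs(v: V, t: Tm, e: Env): Boolean = walk(t, e) match {
    case w: V => w == v
    case Fn(_, args) => args.exists(a => occurs(v, a, e))
  }
  def unify(a: Tm, b: Tm, e: Env): Option[Env] = (walk(a, e), walk(b, e)) match {
    case (x: V, y: V) if x == y => Some(e)
    case (x: V, t) => if (occurs(x, t, e)) None else Some(e.updated(x, t))
    case (t, y: V) => if (occurs(y, t, e)) None else Some(e.updated(y, t))
    case (Fn(f, xs), Fn(g, ys)) if f == g && xs.length == ys.length =>
      xs.zip(ys).foldLeft(Option(e)) { case (acc, (l, r)) => acc.flatMap(e1 => unify(l, r, e1)) }
    case _ => None
  }
  def resolve(t: Tm, e: Env): Tm = walk(t, e) match {
    case Fn(f, args) => Fn(f, args.map(a => resolve(a, e)))
    case v => v
  }

  // The post's =:=, | and & as plain functions (by-name to allow recursion).
  def eqTerm(a: Tm, b: Tm): G = (e: Env) => LazyList.from(unify(a, b, e))
  def or(g1: G, g2: => G): G = (e: Env) => g1(e) #::: g2(e)
  def and(g1: G, g2: => G): G = (e: Env) => g1(e).flatMap(e1 => g2(e1))

  private var counter = 0
  def fresh(name: String): V = { counter += 1; V(name, counter) }

  def int(n: Int): Tm = Fn(n.toString, Nil)
  val nil: Tm = Fn("[]", Nil)
  def cons(h: Tm, t: Tm): Tm = Fn(".", List(h, t))
  def list(ts: Tm*): Tm = ts.foldRight(nil)(cons)

  // member(X, [X|T]).
  // member(X, [_|T]) :- member(X, T).
  def member(x: Tm, l: Tm): G = {
    val hd = fresh("hd"); val tl = fresh("tl")
    or(eqTerm(cons(x, tl), l), and(eqTerm(cons(hd, tl), l), member(x, tl)))
  }

  def run(g: G, n: Int, tm: Tm): List[Tm] =
    g(Map.empty).take(n).toList.map(e => resolve(tm, e))
}
```

Running member with a fresh variable against the list 1, 2, 3 yields each element in order, as in the REPL session above.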

As another example, we can implement addition over unary natural numbers. In Prolog this would be

  sum(z, N, N). 
  sum(s(M), N, s(P)) :- sum(M, N, P). 

In Prolog we can just invent symbols like s and z; in Scala we need first to define a type of natural numbers, then terms over that type:

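  sealed trait Nat 
  case object Z extends Nat 
  case class S(n: Nat) extends Nat 
 
  case object ZTerm extends Term[Nat] { 
    // unify, occurs, subst, ground: like NilTerm (ground is Z)
  }
 
  case class STerm(n: Term[Nat]) extends Term[Nat] { 
    // unify, occurs, subst, ground: like ConsTerm (ground is S(n.ground))
  }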

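Then we can define sum, again separating the clauses by | and explicitly unifying the clause heads. (This code and the calls below use Z and S(...) directly where terms are expected, so they rely on an implicit conversion from Nat to Term[Nat] that is not shown here.)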

  def sum(m: Term[Nat], n: Term[Nat], p: Term[Nat]): G = { 
    val m2 = Evar[Nat]("m"); val p2 = Evar[Nat]("p") 
    (m =:= Z & n =:= p) | 
    (m =:= STerm(m2) & p =:= STerm(p2) & sum(m2, n, p2)) 
  } 

We can use sum to do addition:

scala> val x = Evar[Nat]("x"); val y = Evar[Nat]("y") 
scala> run(sum(S(Z), S(S(Z)), x), 1, x) 
res8: List[Term[Nat]] = List(S(S(S(Z)))) 

or subtraction:

scala> run(sum(x, S(S(Z)), S(S(S(Z)))), 1, x) 
res10: List[Term[Nat]] = List(S(Z)) 
 
scala> run(sum(S(Z), x, S(S(S(Z)))), 1, x) 
res11: List[Term[Nat]] = List(S(S(Z))) 

or even to find all the pairs of naturals which sum to 3:

scala> run(sum(x, y, S(S(S(Z)))), 10, List[Term[Nat]](x, y)) 
res14: List[Term[List[Nat]]] = 
  List(Z :: S(S(S(Z))) :: List(), 
       S(Z) :: S(S(Z)) :: List(), 
       S(S(Z)) :: S(Z) :: List(), 
       S(S(S(Z))) :: Z :: List()) 

although the printing of Term[List] could be better.

This is only a small taste of the expressivity of Prolog-style logic programming. Again let me recommend Frank Pfenning’s course notes, which explore the semantics of Prolog in a “definitional interpreters” style, by gradually refining an interpreter to expose more of the machinery of the language.

See the full code.

Friday, April 29, 2011

Logic programming in Scala, part 2: backtracking

In the previous post we saw how to write computations in a logic monad, where a “value” is a choice among alternatives, and operating on a value means operating on all the alternatives.

Our first implementation of the logic monad represents a choice among alternatives as a list, and operating on a value means running the operation for each alternative immediately (to produce a new list of alternatives). If we imagine alternatives as leaves of a tree (with | indicating branching), the first implementation explores the tree breadth-first.

This is OK for some problems, but we run into trouble when there are a large or infinite number of alternatives. For example, a choice among the natural numbers:

scala> import LogicList._ 
import LogicList._ 
 
scala> val nat: T[Int] = unit(0) | nat.map(_ + 1) 
java.lang.NullPointerException 
        ... 

This goes wrong because even though the right-hand argument to | is by-name, we immediately try to use it, and fail because nat is not yet defined.

scala> def nat: T[Int] = unit(0) | nat.map(_ + 1) 
nat: LogicList.T[Int] 
scala> run(nat, 10) 
java.lang.StackOverflowError 
        ... 

With def we can successfully define nat, because the right-hand side isn’t evaluated until nat is used in the call to run, but we overflow the stack trying to compute all the natural numbers.

Let’s repair this with a fancier implementation of the logic monad, translated from Kiselyov et al.’s Backtracking, Interleaving, and Terminating Monad Transformers. This implementation will explore the tree depth-first.

Success and failure continuations

The idea is to represent a choice of alternatives by a function, which takes as arguments two functions: a success continuation and a failure continuation. The success continuation is just a function indicating what to do next with each alternative; the failure continuation is what to do next when there are no more alternatives.

For success, what we do next is either return the alternative (when we have reached a leaf of the tree), or perform some operation on it (possibly forming new branches rooted at the alternative). For failure, what we do next is back up to the last branch point and succeed with the next alternative. If there are no more alternatives at the previous branch point we back up again, and so on until we can succeed or finally run out of alternatives. In other words, we do depth-first search on the tree, except that the tree isn’t a materialized data structure—it’s created on the fly.

(In the jargon of logic programming, a branch point is called a “choice point”, and going back to an earlier choice point is called “backtracking”.)

object LogicSFK extends Logic { 
  type FK[R] = () => R 
  type SK[A,R] = (A, FK[R]) => R 
 
  trait T[A] { def apply[R](sk: SK[A,R], fk: FK[R]): R } 

The continuations can return a result of some arbitrary type R. This means that the function representing a choice has a “rank-2” polymorphic type—it takes functions which are themselves polymorphic—which is not directly representable in Scala. But we can encode it by making the representation function a method on a trait.

The success continuation takes a value of the underlying type (i.e. an alternative), and also a failure continuation, to call in case this branch of the tree eventually fails (by calling fail, or filter when no alternative satisfies the predicate). The failure continuation is also called to succeed with the next alternative after returning a leaf (see split).
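The same trick works for any rank-2 type. Here is a standalone illustration, a Church-encoded list that is not part of the post’s code (the names ChurchList, onNil, and onCons are hypothetical). Because apply is a polymorphic method, one value can be consumed at several result types R.

```scala
// A rank-2 type encoded as a polymorphic method on a trait: a Church-encoded
// list can be folded into a result of whatever type R the caller picks.
trait ChurchList[A] {
  def apply[R](onNil: R, onCons: (A, R) => R): R
}

object ChurchList {
  def empty[A]: ChurchList[A] = new ChurchList[A] {
    def apply[R](onNil: R, onCons: (A, R) => R): R = onNil
  }
  def cons[A](h: A, t: ChurchList[A]): ChurchList[A] = new ChurchList[A] {
    // Fold the tail first, then combine the result with the head.
    def apply[R](onNil: R, onCons: (A, R) => R): R = onCons(h, t(onNil, onCons))
  }
}
```

The same list can be summed (R = Int), converted back to a List (R = List[Int]), or rendered as a String (R = String).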

  def fail[A] = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = fk() 
    } 
 
  def unit[A](a: A) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = sk(a, fk) 
    } 

To fail, just call the failure continuation. To succeed with one alternative, call the success continuation with the single alternative and the passed-in failure continuation—there are no more alternatives to try, so if this branch fails the unit fails.

  def or[A](t1: T[A], t2: => T[A]) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = 
        t1(sk, { () => t2(sk, fk) }) 
    } 

Or creates a choice point. We want to explore the alternatives in both t1 and t2, so we pass the success continuation to t1 (which calls it on each alternative); when t1 is exhausted we pass the success continuation to t2; finally we fail with the caller’s failure continuation—that is, we backtrack.

  def bind[A,B](t: T[A], f: A => T[B]) = 
    new T[B] { 
      def apply[R](sk: SK[B,R], fk: FK[R]) = 
        t(({ (a, fk) => f(a)(sk, fk) }: SK[A,R]), fk) 
    } 
 
  def apply[A,B](t: T[A], f: A => B) = 
    new T[B] { 
      def apply[R](sk: SK[B,R], fk: FK[R]) = 
        t(({ (a, fk) => sk(f(a), fk) }: SK[A,R]), fk) 
    } 

For bind we extend each branch by calling f on the current leaf. To succeed we call f on the alternative a. Now f(a) returns a choice of alternatives, so we pass it the original success continuation (which says what to do next with alternatives resulting from the bind), and the failure continuation in force at the point a was generated (which succeeds with the next available alternative from f(a)).

For apply things are simpler, since f(a) returns a single value rather than a choice of alternatives: we succeed immediately with the returned value.

  def filter[A](t: T[A], p: A => Boolean) = 
    new T[A] { 
      def apply[R](sk: SK[A,R], fk: FK[R]) = { 
        val sk2: SK[A,R] = 
          { (a, fk) => if (p(a)) sk(a, fk) else fk() } 
        t(sk2, fk) 
      } 
    } 

To filter a choice of alternatives, each time we succeed with a value we see if it satisfies the predicate p; if it does, we succeed with that value (extending the branch), otherwise we fail (pruning the branch).
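Putting fail, unit, or, bind, and filter together gives the self-contained sketch below. It is adapted from the definitions above, with explicit types so it compiles on its own, and it omits split. The object name MiniSFK and the observer toList are hypothetical. toList passes a success continuation that conses each alternative onto whatever the failure continuation produces, which makes the depth-first order visible. Like any collect-everything observer, it terminates only for finite choices.

```scala
object MiniSFK {
  type FK[R] = () => R
  type SK[A, R] = (A, FK[R]) => R

  trait T[A] { def apply[R](sk: SK[A, R], fk: FK[R]): R }

  def fail[A]: T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R = fk()
  }
  def unit[A](a: A): T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R = sk(a, fk)
  }
  // Choice point: explore t1, and when it is exhausted, explore t2.
  def or[A](t1: T[A], t2: => T[A]): T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R = t1(sk, () => t2(sk, fk))
  }
  // Extend each branch by running f on each alternative of t.
  def bind[A, B](t: T[A], f: A => T[B]): T[B] = new T[B] {
    def apply[R](sk: SK[B, R], fk: FK[R]): R =
      t((a: A, fk1: FK[R]) => f(a)(sk, fk1), fk)
  }
  // Prune branches whose value fails p by calling the failure continuation.
  def filter[A](t: T[A], p: A => Boolean): T[A] = new T[A] {
    def apply[R](sk: SK[A, R], fk: FK[R]): R =
      t((a: A, fk1: FK[R]) => if (p(a)) sk(a, fk1) else fk1(), fk)
  }

  // Observe every alternative in depth-first order: the success continuation
  // conses the alternative onto whatever the rest of the search produces.
  def toList[A](t: T[A]): List[A] =
    t.apply[List[A]]((a: A, fk: FK[List[A]]) => a :: fk(), () => Nil)
}
```

Binding the choice of 1 or 2 to a function that returns x*10 or x*10+1 yields 10, 11, 20, 21. Each branch is explored fully before backtracking to the next.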

  def split[A](t: T[A]) = { 
    def unsplit(fk: FK[Option[(A,T[A])]]): T[A] = 
      fk() match { 
        case None => fail 
        case Some((a, t)) => or(unit(a), t) 
      } 
    def sk : SK[A,Option[(A,T[A])]] = 
      { (a, fk) => Some((a, bind(unit(fk), unsplit))) } 
    t(sk, { () => None }) 
  } 
} 

The point of split is to pull a single alternative from a choice, returning along with it a choice of the remaining alternatives. In the list implementation we just returned the head and tail of the list. In this implementation the alternatives are computed on demand, so we want to be careful to do only as much computation as is needed to pull the first alternative.

The failure continuation we pass to t just returns None when there are no more alternatives. The success continuation sk returns the first alternative and a choice of the remaining alternatives (wrapped in Some).

The tricky part is the choice of remaining alternatives. We’re given the failure continuation fk; calling it calls sk on the next alternative, which ultimately returns Some(a, t) where a is the next alternative, or None if there are no more alternatives. We repackage this Option as a choice of alternatives with unsplit. So that we don’t call fk too soon, we call unsplit via bind, which defers the call until the resulting choice of alternatives is actually used.

Now we can write infinite choices:

scala> import LogicSFK._ 
import LogicSFK._ 
 
scala> val nat: T[Int] = unit(0) | nat.map(_ + 1) 
nat: LogicSFK.T[Int] = LogicSFK$$anon$3@27aea0c1 
 
scala> run(nat, 10) 
res1: List[Int] = List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) 

Well, this is a pretty complicated way to generate the natural numbers up to 10…

While nat looks like a lazy stream (as you might write in Haskell), no results are memoized (as they are in Haskell). To compute each successive number all the previous ones must be recomputed, and the running time of run(nat, N) is O(N²).
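For contrast, here is the memoizing version. It assumes Scala 2.13 or later (LazyList; earlier versions had Stream), and the object name Nats is hypothetical. Each element is cached once computed, so taking the first N elements takes time linear in N. The price is that we give up the logic-monad interface (or, bind, filter, split over choices).

```scala
// Scala 2.13+ LazyList memoizes: once an element has been computed it is
// cached, so each new element costs one map step over the previous one.
object Nats {
  lazy val nat: LazyList[Int] = 0 #:: nat.map(_ + 1)
}
```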

Defunctionalization

The code above is a fairly direct translation of the Haskell code from the paper. But its use of continuation-passing style doesn’t map well to Scala, which doesn’t implement general tail-call elimination (because the JVM doesn’t). Every call to a success or failure continuation adds a frame to the stack, even though all we ever do with the result is return it (i.e. the call is in tail position), so the stack frame could be eliminated.

Surprisingly, we run out of memory before we run out of stack:

scala> run(nat, 2000) 
java.lang.OutOfMemoryError: Java heap space 
 ... 

A little heap profiling shows that we’re using quadratic space as well as quadratic time. It turns out that the implementation of Logic.run (from the previous post) has a space leak. The call to run is not tail-recursive, so the stack frame hangs around, and although t is dead after split(t), there’s still a reference to it on the stack.

We can rewrite run with an accumulator to be tail-recursive:

  def run[A](t: T[A], n: Int): List[A] = { 
    def runAcc(t: T[A], n: Int, acc: List[A]): List[A] = 
      if (n <= 0) acc.reverse else 
        split(t) match { 
          case None => acc.reverse 
          case Some((a, t)) => runAcc(t, n - 1, a :: acc) 
        } 
    runAcc(t, n, Nil) 
  } 

Now scalac compiles runAcc as a loop, so there are no stack frames holding on to dead values of t, and we get the expected:

scala> run(nat, 9000) 
java.lang.StackOverflowError 
 ... 

To address the stack overflow we turn to defunctionalization. The idea (from John Reynolds’s classic paper Definitional Interpreters for Higher-Order Programming Languages) is to replace functions and their applications with data constructors (we’ll use case classes) and an apply function, which matches the data constructor and does whatever the corresponding function body does. If a function captures variables, the data constructor must capture the same variables.
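Here is the idea on a deliberately tiny example (not from the post; all names are hypothetical). Two closure-producing functions, add and mul, become case classes Add and Mul that capture the same free variable n. A single apply function then does what the closure bodies did.

```scala
// Higher-order version: each operation is a closure capturing n.
object HigherOrder {
  def add(n: Int): Int => Int = x => x + n
  def mul(n: Int): Int => Int = x => x * n
  def runAll(fs: List[Int => Int], x: Int): Int = fs.foldLeft(x)((acc, f) => f(acc))
}

// Defunctionalized version: one case class per lambda, capturing the same
// free variable n, plus a single apply function containing the lambda bodies.
object Defunctionalized {
  sealed trait Fn
  final case class Add(n: Int) extends Fn
  final case class Mul(n: Int) extends Fn

  def apply(f: Fn, x: Int): Int = f match {
    case Add(n) => x + n
    case Mul(n) => x * n
  }
  def runAll(fs: List[Fn], x: Int): Int = fs.foldLeft(x)((acc, f) => apply(f, acc))
}
```

Both versions compute (2 + 1) * 3 = 9. Unlike the closures, the defunctionalized values are plain data: they can be compared, printed, and matched on.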

After defunctionalization we’re left with three mutually recursive apply functions (one for each of T, FK, and SK) where each recursive call is in tail position. In theory the compiler could transform these into code that takes only constant stack space (since they are local functions private to split). But in fact scalac does so only for a function that calls itself, so we will need to do this transformation by hand.

There is one hitch: the original code is not completely tail-recursive, because of unsplit, which calls a failure continuation then matches on the result. To fix this we need to add yet another continuation, which represents what to do after returning a result from a success or failure continuation.

object LogicSFKDefunc extends Logic { 
  type O[A] = Option[(A,T[A])] 
 
  sealed trait T[A] 
  case class Fail[A]() extends T[A] 
  case class Unit[A](a: A) extends T[A] 
  case class Or[A](t1: T[A], t2: () => T[A]) extends T[A] 
  case class Bind[A,B](t: T[A], f: A => T[B]) extends T[B] 
  case class Apply[A,B](t: T[A], f: A => B) extends T[B] 
  case class Filter[A](t: T[A], p: A => Boolean) extends T[A] 
  case class Unsplit[A](fk: FK[O[A]]) extends T[A] 
 
  def fail[A] = Fail() 
  def unit[A](a: A) = Unit(a) 
  def or[A](t1: T[A], t2: => T[A]) = Or(t1, { () => t2 }) 
  def bind[A,B](t: T[A], f: A => T[B]) = Bind(t, f) 
  def apply[A,B](t: T[A], f: A => B) = Apply(t, f) 
  def filter[A](t: T[A], p: A => Boolean) = Filter(t, p) 

A choice of alternatives T[A] is now represented symbolically by case classes, and the functions which operate on choices just return the corresponding case. The cases capture the same variables that were captured in the original functions.

We have an additional case Unsplit which represents the bind(unit(fk), unsplit) combination from split. And we use O[A] as a convenient abbreviation.

  sealed trait FK[R] 
  case class FKOr[A,R](t: () => T[A], sk: SK[A,R], fk: FK[R]) 
    extends FK[R] 
  case class FKSplit[R](r: R) extends FK[R] 
 
  sealed trait SK[A,R] 
  case class SKBind[A,B,R](f: A => T[B], sk: SK[B,R]) 
    extends SK[A,R] 
  case class SKApply[A,B,R](f: A => B, sk: SK[B,R]) 
    extends SK[A,R] 
  case class SKFilter[A,R](p: A => Boolean, sk: SK[A,R]) 
    extends SK[A,R] 
  case class SKSplit[A,R](r: (A, FK[R]) => R) extends SK[A,R] 
 
  sealed trait K[R,R2] 
  case class KReturn[R]() extends K[R,R] 
  case class KUnsplit[A,R,R2](sk: SK[A,R], fk:FK[R], k: K[R,R2]) 
    extends K[O[A],R2] 

Each case for FK (respectively SK) corresponds to a failure (respectively success) continuation function in the original code; it’s easy to match them up.

The K cases are for the new return continuation. They are defunctionalized from functions R => R2; we can either return a value directly, or match on whether it is Some or None as in unsplit. (If K is hard to understand you might try “refunctionalizing” it by replacing the cases with functions.)

We see that case classes are more powerful than variants in OCaml, without GADTs at least. Cases can have “input” type variables (appearing in arguments) which do not appear in the “output” (the type the case extends). When we match on the case these are treated as existentials. And the output type of a case can be more specific than the general form of the trait (KReturn extends K[R,R] rather than K[R,R2]); when we match on the case we can make more restrictive assumptions about types in that branch of the match. More on this in Emir, Odersky, and Williams’ Matching Objects with Patterns.
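Here is a small standalone illustration of both features, using the classic typed-expression evaluator. It is not part of the post’s code, and all names are hypothetical. Matching IntLit refines the result type A to Int in that branch. KeepFirst has an input-only type variable B that is unknown (existential) when matched.

```scala
object GadtDemo {
  sealed trait Expr[A]
  // These cases extend a more specific type than Expr[A]; matching refines A.
  final case class IntLit(n: Int) extends Expr[Int]
  final case class BoolLit(b: Boolean) extends Expr[Boolean]
  final case class Add(l: Expr[Int], r: Expr[Int]) extends Expr[Int]
  final case class IsZero(e: Expr[Int]) extends Expr[Boolean]
  final case class If[A](c: Expr[Boolean], t: Expr[A], e: Expr[A]) extends Expr[A]
  // B appears only in the arguments, not in Expr[A]: an "input" type variable.
  final case class KeepFirst[A, B](keep: Expr[A], drop: Expr[B]) extends Expr[A]

  def eval[A](e: Expr[A]): A = e match {
    case IntLit(n) => n                 // in this branch A = Int
    case BoolLit(b) => b                // in this branch A = Boolean
    case Add(l, r) => eval(l) + eval(r)
    case IsZero(x) => eval(x) == 0
    case If(c, t, f) => if (eval(c)) eval(t) else eval(f)
    case KeepFirst(k, _) => eval(k)     // B is unknown here (existential)
  }
}
```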

  def split[A](t: T[A]) = { 
 
    def applyT[A,R,R2] 
      (t: T[A], sk: SK[A,R], fk: FK[R], k: K[R,R2]): R2 = 
      t match { 
        case Fail() => applyFK(fk, k) 
        case Unit(a) => applySK(sk, a, fk, k) 
        case Or(t1, t2) => applyT(t1, sk, FKOr(t2, sk, fk), k) 
        case Bind(t, f) => applyT(t, SKBind(f, sk), fk, k) 
        case Apply(t, f) => applyT(t, SKApply(f, sk), fk, k) 
        case Filter(t, p) => applyT(t, SKFilter(p, sk), fk, k) 
        case Unsplit(fk2) => applyFK(fk2, KUnsplit(sk, fk, k)) 
      } 
 
    def applyFK[R,R2](fk: FK[R], k: K[R,R2]): R2 = 
      fk match { 
        case FKOr(t, sk, fk) => applyT(t(), sk, fk, k) 
        case FKSplit(r) => applyK(k, r) 
      } 
 
    def applySK[A,R,R2] 
      (sk: SK[A,R], a: A, fk: FK[R], k: K[R,R2]): R2 = 
      sk match { 
        case SKBind(f, sk) => applyT(f(a), sk, fk, k) 
        case SKApply(f, sk) => applySK(sk, f(a), fk, k) 
        case SKFilter(p, sk) => 
          if (p(a)) applySK(sk, a, fk, k) else applyFK(fk, k) 
        case SKSplit(rf) => applyK(k, rf(a, fk)) 
      } 

Again, each of these cases corresponds directly to a function in the original code, and again it is easy to match them up (modulo the extra return continuation argument) to see that all we have done is separated the data part of the function (i.e. the captured variables) from the code part.

The exception is Unsplit, which again corresponds to bind(unit(fk), unsplit). To apply it, we apply fk (which collapses unit(fk), bind, and the application of fk in unsplit) with KUnsplit as continuation, capturing sk, fk, and k (corresponding to their capture in the success continuation of bind).

    def applyK[R,R2](k: K[R,R2], r: R): R2 = 
      k match { 
        case KReturn() => r.asInstanceOf[R2] 
        case KUnsplit(sk, fk, k) => { 
          r match { 
            case None => applyFK(fk, k) 
            case Some((a, t)) => applyT(or(unit(a), t), sk, fk, k) 
          } 
        } 
      } 

For KReturn we just return the result. Although KReturn extends K[R,R], Scala doesn’t deduce from this that R = R2, so we must coerce the result. For KUnsplit we do the same match as unsplit, then apply the resulting T (for the None case we call the failure continuation directly instead of applying fail). Here Scala deduces from the output type of KUnsplit (K[O[A],R2]) that it is safe to treat r as an Option.

    applyT[A,O[A],O[A]]( 
      t, 
      SKSplit((a, fk) => Some((a, Unsplit(fk)))), 
      FKSplit(None), 
      KReturn()) 
  } 
} 

Finally, we apply the input T to the defunctionalized forms of the original split’s success and failure continuations, with KReturn as the return continuation.
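It may help to see what split and unsplit mean in a much simpler reference model. This is a hedged sketch, not the representation used in this post. It assumes a computation is a lazy list of its successes (ListModel and its definitions are hypothetical; LazyList needs Scala 2.13 or later). In this model split peels off the first success together with the remaining computation. unsplit undoes it exactly as the KUnsplit case does, mapping None to fail and Some((a, t)) to or(unit(a), t).

```scala
object ListModel {
  // Hypothetical model: a computation is the lazy list of its successes.
  type T[A] = LazyList[A]
  type O[A] = Option[(A, T[A])]

  def fail[A]: T[A] = LazyList.empty
  def unit[A](a: A): T[A] = LazyList(a)
  // t2 stays unevaluated until the successes of t1 are exhausted.
  def or[A](t1: T[A], t2: => T[A]): T[A] = t1.lazyAppendedAll(t2)

  // First success plus the rest of the computation, or None.
  def split[A](t: T[A]): O[A] =
    if (t.isEmpty) None else Some((t.head, t.tail))

  // Inverse of split, mirroring the KUnsplit case.
  def unsplit[A](o: O[A]): T[A] = o match {
    case None         => fail
    case Some((a, t)) => or(unit(a), t)
  }
}
```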

Tail call elimination

(This section has been revised; you can see the original here.)

To eliminate the stack frames from tail calls, we next rewrite the four mutually recursive functions as a single self-recursive function, which Scala compiles into a loop. This matters because scalac eliminates a tail call only when a function calls itself, not when it calls a different function. To do this we have to give up some type safety, but only in the implementation of the Logic monad; we’ll still present the same safe interface.
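The merging step can be illustrated on a toy pair of functions. This is a hypothetical example, not from the post. isEven and isOdd call each other in tail position, but none of those calls is a self tail call, so each one still uses a stack frame. Merging them into one function that dispatches on a tag naming the original function makes every call a self tail call, which @tailrec then checks.

```scala
import scala.annotation.tailrec

object MergeDemo {
  // Mutually recursive: every call is a tail call, but none is a
  // self tail call, so each one still pushes a stack frame.
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)

  // A tag recording which original function we are "in".
  sealed trait Fn
  case object Even extends Fn
  case object Odd extends Fn

  // Every recursive call is now a self tail call, so scalac compiles this
  // into a loop (@tailrec makes compilation fail if it could not).
  @tailrec
  def run(fn: Fn, n: Int): Boolean = fn match {
    case Even => if (n == 0) true else run(Odd, n - 1)
    case Odd  => if (n == 0) false else run(Even, n - 1)
  }
}
```

The Logic version follows the same shape, with the instruction I playing the role of the tag.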

object LogicSFKDefuncTailrec extends Logic { 
  type O[A] = Option[(A,T[A])] 
 
  type T[A] = I 
 
  // One erased type for every computation and continuation.
  sealed trait I 

  // Logic computations (formerly T[A])
  case class Fail() extends I 
  case class Unit(a: Any) extends I 
  case class Or(t1: I, t2: () => I) extends I 
  case class Bind(t: I, f: Any => I) extends I 
  case class Apply(t: I, f: Any => Any) extends I 
  case class Filter(t: I, p: Any => Boolean) extends I 
  case class Unsplit(fk: I) extends I 

  // Failure continuations (formerly FK[R])
  case class FKOr(t: () => I, sk: I, fk: I) extends I 
  case class FKSplit(r: O[Any]) extends I 

  // Success continuations (formerly SK[A,R])
  case class SKBind(f: Any => I, sk: I) extends I 
  case class SKApply(f: Any => Any, sk: I) extends I 
  case class SKFilter(p: Any => Boolean, sk: I) extends I 
  case class SKSplit(r: (Any, I) => O[Any]) extends I 

  // Return continuations (formerly K[R,R2])
  case object KReturn extends I 
  case class KUnsplit(sk: I, fk: I, k: I) extends I 

This is all pretty much as before, except that we erase all the type parameters. Having done so, we can combine the four defunctionalized types into a single type I (for “instruction” perhaps), which lets us write a single recursive apply function. The type parameter in T[A] is then a phantom type, since it does not appear on the right-hand side of the definition; it is used only to enforce constraints outside the module.
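The following minimal sketch shows how a phantom parameter over an erased representation still enforces constraints outside the module. PhantomDemo, Cell, ErasedCell, and Boxed are hypothetical names. Inside ErasedCell, T[Int] and T[String] are both just I. Through the abstract Cell interface, as with the abstract T[A] of Logic, they are distinct types.

```scala
object PhantomDemo {
  // The public interface: T[A] is abstract.
  trait Cell {
    type T[A]
    def wrap[A](a: A): T[A]
    def unwrap[A](t: T[A]): A
  }

  // One implementation erases A: every T[A] is the same type I.
  object ErasedCell extends Cell {
    sealed trait I
    final case class Boxed(a: Any) extends I
    type T[A] = I // A is phantom: it does not appear on the right
    def wrap[A](a: A): T[A] = Boxed(a)
    def unwrap[A](t: T[A]): A = t match {
      case Boxed(a) => a.asInstanceOf[A] // safe only if callers respect A
    }
  }
}
```

Given val c: PhantomDemo.Cell = PhantomDemo.ErasedCell, the call c.unwrap[String](c.wrap(42)) does not type-check, because c.T[Int] is not c.T[String]. The same call made directly on ErasedCell compiles, because inside the module both types are I.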

  def fail[A]: T[A] = Fail() 
  def unit[A](a: A): T[A] = Unit(a) 
  def or[A](t1: T[A], t2: => T[A]): T[A] = Or(t1, { () => t2 }) 
  def bind[A,B](t: T[A], f: A => T[B]): T[B] = 
    Bind(t, f.asInstanceOf[Any => I]) 
  def apply[A,B](t: T[A], f: A => B): T[B] = 
    Apply(t, f.asInstanceOf[Any => Any]) 
  def filter[A](t: T[A], p: A => Boolean): T[A] = 
    Filter(t, p.asInstanceOf[Any => Boolean]) 

The functions for building T[A] values are mostly the same. We have to cast the passed-in functions. Function types are contravariant in their argument type, so an A => B is not an Any => B, since Any is not a subtype of arbitrary A. The return type annotations don’t seem necessary, but without them I saw some strange type errors (possibly related to the phantom type?) when using the Logic.Syntax wrapper.
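The following hypothetical two-line example (CastDemo is illustrative) shows why these casts are needed and why they are harmless here. An Int => Boolean is not an Any => Boolean, so the compiler requires an explicit asInstanceOf. Because of erasure the cast itself checks nothing about the argument type. It is safe only as long as the function is applied to values of its original argument type, which the phantom-typed Logic interface ensures.

```scala
object CastDemo {
  val positive: Int => Boolean = n => n > 0

  // val bad: Any => Boolean = positive  // rejected: Any is not a subtype of Int
  val erased: Any => Boolean = positive.asInstanceOf[Any => Boolean]

  // erased(5) behaves like positive(5). erased("five") would compile but
  // fail at run time, because the function body still expects an Int.
}
```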

  def split[A](t: T[A]): O[A] = { 
    def apply(i: I, a: Any, r: O[Any], sk: I, fk: I, k: I): O[Any] = 
      i match { 
        // computations: push continuations and descend
        case Fail() => apply(fk, null, null, null, null, k) 
        case Unit(a) => apply(sk, a, null, null, fk, k) 
        case Or(t1, t2) => 
          apply(t1, null, null, sk, FKOr(t2, sk, fk), k) 
        case Bind(t, f) => 
          apply(t, null, null, SKBind(f, sk), fk, k) 
        case Apply(t, f) => 
          apply(t, null, null, SKApply(f, sk), fk, k) 
        case Filter(t, p) => 
          apply(t, null, null, SKFilter(p, sk), fk, k) 
        case Unsplit(fk2) => 
          apply(fk2, null, null, null, null, KUnsplit(sk, fk, k)) 

        // failure continuations
        case FKOr(t, sk, fk) => apply(t(), null, null, sk, fk, k) 
        case FKSplit(r) => apply(k, null, r, null, null, null) 

        // success continuations (a is the value passed in)
        case SKBind(f, sk) => apply(f(a), null, null, sk, fk, k) 
        case SKApply(f, sk) => apply(sk, f(a), null, null, fk, k) 
        case SKFilter(p, sk) => 
          if (p(a)) 
            apply(sk, a, null, null, fk, k) 
          else 
            apply(fk, null, null, null, null, k) 
        case SKSplit(rf) => 
          apply(k, null, rf(a, fk), null, null, null) 

        // return continuations (r is the result passed in)
        case KReturn => r 
        case KUnsplit(sk, fk, k) => { 
          r match { 
            case None => apply(fk, null, null, null, null, k) 
            case Some((a, t)) => 
              apply(or(unit(a), t), null, null, sk, fk, k) 
          } 
        } 
      } 

    apply(t, 
          null, 
          null, 
          SKSplit((a, fk) => Some((a, Unsplit(fk)))), 
          FKSplit(None), 
          KReturn).asInstanceOf[O[A]] 
  } 
} 

The original functions took different arguments; the single function takes every argument that any of them took. In each call we pass null for the arguments the target case doesn’t use, but otherwise the cases are the same as before.
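The following cut-down, hypothetical version of the single-function machine shows the whole technique in runnable form. MiniTailrec, Ret, FKDone, SKDone, firstSolution, and failsThen are illustrative names. It supports only Fail, unit (here Ret), Or, and Bind, and it returns just the first solution instead of implementing split. As in the post, every computation and continuation is one erased type I, and unused arguments are passed as null. @tailrec confirms that the single apply function compiles to a loop.

```scala
import scala.annotation.tailrec

object MiniTailrec {
  // One erased type for computations and continuations alike.
  sealed trait I
  // computations
  case object Fail extends I
  final case class Ret(a: Any) extends I // plays the role of Unit
  final case class Or(t1: I, t2: () => I) extends I
  final case class Bind(t: I, f: Any => I) extends I
  // failure continuations
  final case class FKOr(t: () => I, sk: I, fk: I) extends I
  case object FKDone extends I
  // success continuations
  final case class SKBind(f: Any => I, sk: I) extends I
  case object SKDone extends I

  def firstSolution(t: I): Option[Any] = {
    // a is meaningful only when i is a success continuation; null otherwise.
    @tailrec
    def apply(i: I, a: Any, sk: I, fk: I): Option[Any] =
      i match {
        case Fail              => apply(fk, null, null, null)
        case Ret(x)            => apply(sk, x, null, fk)
        case Or(t1, t2)        => apply(t1, null, sk, FKOr(t2, sk, fk))
        case Bind(t, f)        => apply(t, null, SKBind(f, sk), fk)
        case FKOr(t, sk2, fk2) => apply(t(), null, sk2, fk2)
        case FKDone            => None
        case SKBind(f, sk2)    => apply(f(a), null, sk2, fk)
        case SKDone            => Some(a)
      }
    apply(t, null, SKDone, FKDone)
  }

  // Hypothetical helper: n nested Or(Fail, ...) choices before a success.
  def failsThen(n: Int, x: Any): I =
    if (n == 0) Ret(x) else Or(Fail, () => failsThen(n - 1, x))
}
```

Because every step is an iteration of the same loop, firstSolution handles failsThen(100000, ...) or a left-nested chain of 100000 Binds without running out of stack.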

Now we can evaluate nat to large N without running out of stack (though, since the running time is quadratic, it takes longer to finish than I care to wait):

scala> run(nat, 100000) 
^C 

See the complete code here.

Next time we’ll thread state through this backtracking logic monad and use it to implement unification.