Monday, July 25, 2011

Monad Transformers in Scala

Monads don't compose in general .. and hence Monad Transformers. A monad transformer maps monads to monads. It lets you transform a monad with additional computational effects. Stated simply, if you have a monadic computation in place, you can enrich it incrementally with additional effects like state and error handling without disturbing the overall structure of your program.

A monad transformer is represented by the kind T :: (* -> *) -> * -> *. The general contract that a monad transformer offers is ..

class MonadTrans t where
 lift :: (Monad m) => m a -> t m a

Here we lift a computation m a into the context of another effect t. We call t the monad transformer; for any monad m, the resulting t m is itself a monad.

In this post I will discuss monad transformers in Scala using scalaz 7. And here's what you get as the base abstraction corresponding to the Haskell typeclass shown above ..

trait MonadTrans[F[_[_], _]] {
  def lift[G[_] : Monad, A](a: G[A]): F[G, A]
}

It takes a computation G[A] and lifts it into the transformer F, yielding F[G, A], whose monad combines the effects of both. Let's look at a couple of examples in scalaz that compose a new effect with an existing one ..

// lift a List into OptionT, layering the Option effect over List
// uses transLift available from the main pimp of scalaz

scala> List(10, 12).transLift[OptionT].runT
res12: List[Option[Int]] = List(Some(10), Some(12))

// uses explicit constructor methods

scala> optionT(List(some(12), some(50))).runT
res13: List[Option[Int]] = List(Some(12), Some(50))
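The following minimal hand-rolled sketch makes lifting concrete without depending on a particular scalaz snapshot. Monad, MonadTrans, OptionT, listMonad and optionTTrans are my own illustrative definitions modelled on the shapes above. They are not the scalaz API. Here, lift plays the role that transLift plays in the first REPL example, and the OptionT constructor plays the role of optionT in the second.

```scala
// Hypothetical minimal typeclasses modelled on the shapes above (not the scalaz API)
trait Monad[F[_]] {
  def point[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => point(f(a)))
}

trait MonadTrans[T[_[_], _]] {
  def lift[G[_]: Monad, A](ga: G[A]): T[G, A]
}

// OptionT[F, A] wraps an F[Option[A]]; runT unwraps it again
final case class OptionT[F[_], A](runT: F[Option[A]])

// List as the inner monad (must be initialized before it is used below)
implicit val listMonad: Monad[List] = new Monad[List] {
  def point[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

// Lifting runs the inner computation unchanged and wraps every result in Some
val optionTTrans: MonadTrans[OptionT] = new MonadTrans[OptionT] {
  def lift[G[_]: Monad, A](ga: G[A]): OptionT[G, A] =
    OptionT(implicitly[Monad[G]].map(ga)(a => Some(a): Option[A]))
}

val lifted: OptionT[List, Int] = optionTTrans.lift(List(10, 12))
val built: OptionT[List, Int]  = OptionT[List, Int](List(Some(12), Some(50)))

println(lifted.runT) // List(Some(10), Some(12))
println(built.runT)  // List(Some(12), Some(50))
```

Lifting only wraps. The List structure, that is the inner effect, is left untouched, which is exactly what the contract lift :: m a -> t m a promises.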

If you are like me, you have probably started wondering about the practical usage of this monad transformer thingy. What we really want is to see them in action in the kind of meaningful code we write in our daily lives.

In the paper titled Monad Transformers Step by Step, Martin Grabmüller does exactly this in Haskell, evolving a complete interpreter for a small language using these abstractions and highlighting how they contribute towards an effective functional model of your code. In this post I render some of those examples in Scala, using scalaz 7 as the library of implementation. An important point that Martin also makes in his paper is that you need to think functionally and organize your code upfront using monadic structures in order to take full advantage of incremental enrichment through monad transformers.

We will be writing an interpreter for a very small language. I will first define the base abstractions and start with a plain functional implementation as the base. It lacks many useful things like state management and error handling, which I will add incrementally using monad transformers. We will see how the core model remains the same, transformers get added in layers, and the static type of the interpreter function states explicitly what effects have been added to it.

The Language

Here's the language for which we will be writing the interpreter. Pretty basic stuff with literal integers, variables, addition, λ expressions (abstraction) and function application. By abstraction and application I mean lambda terms .. so a quick definition for the uninitiated ..


- a lambda term may be a variable, x
- if t is a lambda term, and x is a variable, then λx.t is a lambda term (called a lambda abstraction)
- if t and s are lambda terms, then ts is a lambda term (called an application)


and the Scala definitions for the language elements ..

// variable names
type Name = String
  
// Expression types
trait Exp
case class Lit(i: Int) extends Exp
case class Var(n: Name) extends Exp
case class Plus(e1: Exp, e2: Exp) extends Exp
case class Abs(n: Name, e: Exp) extends Exp
case class App(e1: Exp, e2: Exp) extends Exp
  
// Value types
trait Value
case class IntVal(i: Int) extends Value
case class FunVal(e: Env, n: Name, exp: Exp) extends Value

// Environment in which the λ-abstraction will be evaluated
type Env = collection.immutable.Map[Name, Value]

// smart constructors that return the supertype Value: in Scala, typeclass
// resolution and the invariance of the monads can otherwise trip up type
// inference when a case class constructor yields IntVal or FunVal

object Values {
  def intval(i: Int): Value = IntVal(i)
  def funval(e: Env, n: Name, exp: Exp): Value = FunVal(e, n, exp)
}
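The comment above is terse, so here is a small sketch of the problem it works around. It assumes the monad wrappers are invariant in their type parameter. Box and PointSyntax are hypothetical stand-ins for Identity and scalaz's point syntax.

```scala
// Hypothetical stand-ins: Box for an invariant monad like Identity,
// PointSyntax for a `point` method added by implicit enrichment
trait Value
final case class IntVal(i: Int) extends Value

final case class Box[A](value: A) // invariant in A

implicit class PointSyntax[A](a: A) {
  def point: Box[A] = Box(a)
}

// smart constructor: the static type is the supertype Value
def intval(i: Int): Value = IntVal(i)

val narrow = IntVal(1).point            // inferred as Box[IntVal]
// val bad: Box[Value] = narrow         // would not compile: Box is invariant
val wide: Box[Value] = intval(1).point  // Box[Value], which is what the evaluator needs
```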

The Reference Implementation

I start with the base implementation, which is a functional model of the interpreter. It contains only the basic stuff for evaluation and has no monadic structure. Incrementally we will start having fun with this ..

def eval0: Env => Exp => Value = { env => exp =>
  exp match {
    case Lit(i) => IntVal(i)
    case Var(n) => (env get n).get
    case Plus(e1, e2) => {
      val IntVal(i1) = eval0(env)(e1)
      val IntVal(i2) = eval0(env)(e2)
      IntVal(i1 + i2)
    }
    case Abs(n, e) => FunVal(env, n, e)
    case App(e1, e2) => {
      val val1 = eval0(env)(e1)
      val val2 = eval0(env)(e2)
      val1 match {
        case FunVal(e, n, exp) => eval0((e + ((n, val2))))(exp)
      }
    }
  }
}

Note that we assume we have the proper matches everywhere: the Map lookup when processing variables (Var) doesn't fail, and we have a proper function value when we go for function application. So things work fine on the correct paths of expression evaluation ..

// Evaluate: 12 + ((λx -> x)(4 + 2))

scala> val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
e1: Plus = Plus(Lit(12),App(Abs(x,Var(x)),Plus(Lit(4),Lit(2))))

scala> eval0(collection.immutable.Map.empty[Name, Value])(e1)
res4: Value = IntVal(18)
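Here is a self-contained sketch of eval0 that makes the happy-path assumption visible. I have used ordinary parameter lists instead of curried function values, and sealed traits; both are my choices, and the logic is otherwise the same as above. The second expression, e2, refers to an unbound variable y.

```scala
import scala.util.Try

type Name = String

sealed trait Exp
final case class Lit(i: Int) extends Exp
final case class Var(n: Name) extends Exp
final case class Plus(e1: Exp, e2: Exp) extends Exp
final case class Abs(n: Name, e: Exp) extends Exp
final case class App(e1: Exp, e2: Exp) extends Exp

sealed trait Value
final case class IntVal(i: Int) extends Value
final case class FunVal(e: Env, n: Name, exp: Exp) extends Value

type Env = Map[Name, Value]

// Same logic as eval0; each partial step throws when we leave the happy path
def eval0(env: Env)(exp: Exp): Value = exp match {
  case Lit(i) => IntVal(i)
  case Var(n) => env.get(n).get // None.get throws NoSuchElementException
  case Plus(e1, e2) =>
    val IntVal(i1) = eval0(env)(e1) // MatchError if the value is not an IntVal
    val IntVal(i2) = eval0(env)(e2)
    IntVal(i1 + i2)
  case Abs(n, e) => FunVal(env, n, e)
  case App(e1, e2) =>
    val val1 = eval0(env)(e1)
    val val2 = eval0(env)(e2)
    val1 match { // MatchError if val1 is not a FunVal
      case FunVal(e, n, body) => eval0(e + (n -> val2))(body)
    }
}

val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
val e2 = Plus(Lit(12), App(Abs("x", Var("y")), Plus(Lit(4), Lit(2))))

val ok     = eval0(Map.empty)(e1)      // IntVal(18)
val failed = Try(eval0(Map.empty)(e2)) // a Failure wrapping NoSuchElementException
```

The failure is an exception escaping from deep inside the evaluator, which is precisely what the monadic versions below replace with values.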

Go Monadic

Monad transformers give you layers of control over the various aspects of your computation. But for that to happen you need to organize your code in a monadic way. Think of it like this: if your code models the computations of your domain (aka the domain logic) as per the contracts of an abstraction, you can compose more such abstractions in layers without directly poking into the underlying implementation.

Let's do one thing - let's transform the above function into a monadic one that doesn't add any effect. It only sets up the base case for other monad transformers to prepare their playing fields. It's the Identity monad, which simply applies the bound function to its input without any additional computational effect. In scalaz 7 Identity simply wraps a value and provides a map and flatMap for bind.
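Identity really is that thin. The hand-rolled sketch below is illustrative only, and the scalaz 7 class differs in its details. It shows the wrapper and how a for comprehension over it desugars into flatMap and map, which is all that the next version of the evaluator relies on.

```scala
// Hypothetical hand-rolled Identity: wraps a value and adds no effect
final case class Identity[A](value: A) {
  def map[B](f: A => B): Identity[B] = Identity(f(value))
  def flatMap[B](f: A => Identity[B]): Identity[B] = f(value) // bind just applies f
}

object Identity {
  def point[A](a: A): Identity[A] = Identity(a) // the equivalent of Haskell's return
}

// A for comprehension over Identity ...
val sum: Identity[Int] = for {
  i <- Identity.point(12)
  j <- Identity.point(6)
} yield i + j

// ... is sugar for nested flatMap / map calls
val desugared: Identity[Int] =
  Identity.point(12).flatMap(i => Identity.point(6).map(j => i + j))

println(sum) // Identity(18)
```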

Here's our next iteration of eval, this time with the Identity monad baked in. eval0 returned Value, whereas eval1 returns Identity[Value]; the return type makes it explicit that we are now in the land of monads and have wrapped ourselves in a computational structure that can only be manipulated within the bounds of the contract the monad allows.

type Eval1[A] = Identity[A]

def eval1: Env => Exp => Eval1[Value] = {env => exp =>
  exp match {
    case Lit(i) => intval(i).point[Eval1]
    case Var(n) => (env get n).get.point[Eval1]
    case Plus(e1, e2) => for {
      i <- eval1(env)(e1)
      j <- eval1(env)(e2)
    } yield {
      val IntVal(i1) = i
      val IntVal(i2) = j
      IntVal(i1 + i2)
    }
    case Abs(n, e) => funval(env, n, e).point[Eval1]
    case App(e1, e2) => for {
      val1 <- eval1(env)(e1)
      val2 <- eval1(env)(e2)
      v    <- val1 match {
                case FunVal(e, n, exp) => eval1((e + ((n, val2))))(exp)
              }
    } yield v
  }
}

All returns are now monadic, though the basic computation remains the same. The Lit, Abs and Var cases use the point function (called pure in scalaz 6), the equivalent of Haskell's return. Plus and App use for comprehensions to sequence the monadic actions. Here's the result on the REPL ..

scala> eval1(collection.immutable.Map.empty[Name, Value])(e1)
res7: Eval1[Value] = scalaz.Identity$$anon$2@18f67fc

scala> res7.value
res8: Value = IntVal(18)

So the Identity monad has successfully installed itself, making our computational model like an onion on which we can now stack up additional layers of effects.

Handling Errors

In eval1 we have a monadic functional model of our computation, but we have not yet handled any errors that may arise from it. And I promised that we would add such effects incrementally without changing the guts of our model.

As a very first step, let's use a monad transformer that helps us handle errors, not by throwing exceptions (exceptions are bad .. right?) but by wrapping the error conditions in yet another abstraction. Needless to say this also has to be monadic because we would like it to compose with our already implemented Identity monad and the others that we will work out later on.

scalaz 7 offers EitherT which we can use as the Error monad transformer. It is defined as ..

sealed trait EitherT[A, F[_], B] {
  val runT: F[Either[A, B]]
  //..
}

It adds the Either effect on top of F, so that the composed monad has both effects. As with Either, we use Left (of type A) for the error condition and Right (of type B) for returning the result. The plus point of using the monad transformer is that the plumbing between the two monads is taken care of by the implementation of EitherT, so that we can simply define the following ..

type Eval2[A] = EitherT[String, Identity, A]

def eval2a: Env => Exp => Eval2[Value] = {env => exp =>
  //..
}

The error is reported as a String, and the Value is returned in the Right constructor of Either; the return type now states explicitly what the function does. You can simply change the return type to Eval2, keep the rest of the function the same as eval1 (with Eval2 as the type argument to point), and it works just like the earlier one. Since we have not yet coded explicitly for the error conditions, no meaningful error messages appear, but the happy paths execute as before even with the changed return type. This works because Identity is a monad, and so is the newly composed monad of EitherT over Identity.

We can run eval2a and the only difference in output will be that the result will be wrapped in a Right constructor ..

scala> val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
e1: Plus = Plus(Lit(12),App(Abs(x,Var(x)),Plus(Lit(4),Lit(2))))

scala> eval2a(collection.immutable.Map.empty[Name, Value])(e1)
res31: Eval2[Value] = scalaz.EitherTs$$anon$2@ad2f60

scala> res31.runT.value
res33: Either[String,Value] = Right(IntVal(18))
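The plumbing that EitherT takes care of can be sketched without scalaz. Monad, bind, EitherT, Identity, right and left below are hypothetical minimal definitions that only mimic the shapes shown above. The interesting part is flatMap: it binds in the inner monad and stops at the first Left.

```scala
// Hypothetical minimal Monad typeclass (a sketch, not the scalaz 7 API)
trait Monad[F[_]] {
  def point[A](a: A): F[A]
  def bind[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// EitherT wraps F[Either[E, A]], mirroring the runT field of the scalaz trait
final case class EitherT[E, F[_], A](runT: F[Either[E, A]]) {
  // Bind in the inner monad F, then inspect the Either inside it.
  // A Left short-circuits: f is never called and the error is re-wrapped in F.
  def flatMap[B](f: A => EitherT[E, F, B])(implicit F: Monad[F]): EitherT[E, F, B] =
    EitherT[E, F, B](F.bind[Either[E, A], Either[E, B]](runT) {
      case Left(e)  => F.point[Either[E, B]](Left(e))
      case Right(a) => f(a).runT
    })

  def map[B](f: A => B)(implicit F: Monad[F]): EitherT[E, F, B] =
    flatMap(a => EitherT[E, F, B](F.point[Either[E, B]](Right(f(a)))))
}

// Identity as the inner monad, as in Eval2
final case class Identity[A](value: A)

implicit val identityMonad: Monad[Identity] = new Monad[Identity] {
  def point[A](a: A): Identity[A] = Identity(a)
  def bind[A, B](fa: Identity[A])(f: A => Identity[B]): Identity[B] = f(fa.value)
}

type Eval2[A] = EitherT[String, Identity, A]

def right[A](a: A): Eval2[A]        = EitherT[String, Identity, A](Identity(Right(a)))
def left[A](err: String): Eval2[A]  = EitherT[String, Identity, A](Identity(Left(err)))

var continuations = 0 // counts how often the step after a Left is reached

val ok: Eval2[Int] = for {
  i <- right(12)
  j <- right(6)
} yield i + j

val failed: Eval2[Int] = for {
  i <- left[Int]("Unbound variable y")
  j <- { continuations += 1; right(6) }
} yield i + j

println(ok.runT)     // Identity(Right(18))
println(failed.runT) // Identity(Left(Unbound variable y))
```

The error path never runs the rest of the computation, and client code never has to unpack the inner Either by hand.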

We can do a couple more iterations, improving how we handle errors with EitherT and issuing appropriate error messages to the user. Here's the final version with all error handling implemented. Note, however, that the core model remains the same; we have only added the Left handling for error conditions ..

def eval2: Env => Exp => Eval2[Value] = {env => exp =>
  exp match {
    case Lit(i) => intval(i).point[Eval2]

    case Var(n) => (env get n).map(v => rightT[String, Identity, Value](v))
                              .getOrElse(leftT[String, Identity, Value]("Unbound variable " + n))
    case Plus(e1, e2) => 
      val r = 
        for {
          i <- eval2(env)(e1)
          j <- eval2(env)(e2)
        } yield((i, j))

      r.runT.value match {
        case Right((IntVal(i_), IntVal(j_))) => rightT(IntVal(i_ + j_))
        case Left(s) => leftT("type error in Plus" + "/" + s)
        case _ => leftT("type error in Plus")
      }

    case Abs(n, e) => funval(env, n, e).point[Eval2]

    case App(e1, e2) => 
      val r =
        for {
          val1 <- eval2(env)(e1)
          val2 <- eval2(env)(e2)
        } yield((val1, val2))

      r.runT.value match {
        case Right((FunVal(e, n, exp), v)) => eval2(e + ((n, v)))(exp)
        case _ => leftT("type error in App")
      }
  }
}
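EitherT[String, Identity, A] carries the same information as a plain Either[String, A], so the error-handling logic of eval2 can be tried out without scalaz. This sketch uses the standard library's right-biased Either (Scala 2.12 or later) as an assumed stand-in for the transformer stack. It keeps the same case structure and error messages as eval2.

```scala
type Name = String

sealed trait Exp
final case class Lit(i: Int) extends Exp
final case class Var(n: Name) extends Exp
final case class Plus(e1: Exp, e2: Exp) extends Exp
final case class Abs(n: Name, e: Exp) extends Exp
final case class App(e1: Exp, e2: Exp) extends Exp

sealed trait Value
final case class IntVal(i: Int) extends Value
final case class FunVal(e: Env, n: Name, exp: Exp) extends Value

type Env = Map[Name, Value]

// Assumed stand-in for EitherT[String, Identity, A]: a plain Either
type Eval2[A] = Either[String, A]

def eval2(env: Env)(exp: Exp): Eval2[Value] = exp match {
  case Lit(i) => Right(IntVal(i))
  case Var(n) => env.get(n).toRight("Unbound variable " + n)
  case Plus(e1, e2) =>
    val r = for {
      i <- eval2(env)(e1)
      j <- eval2(env)(e2)
    } yield (i, j)
    r match {
      case Right((IntVal(i), IntVal(j))) => Right(IntVal(i + j))
      case Left(s) => Left("type error in Plus/" + s) // prefix errors from sub-expressions
      case _       => Left("type error in Plus")
    }
  case Abs(n, e) => Right(FunVal(env, n, e))
  case App(e1, e2) =>
    val r = for {
      f <- eval2(env)(e1)
      v <- eval2(env)(e2)
    } yield (f, v)
    r match {
      case Right((FunVal(e, n, body), v)) => eval2(e + (n -> v))(body)
      case _                              => Left("type error in App")
    }
}

val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
val e2 = Plus(Lit(12), App(Abs("x", Var("y")), Plus(Lit(4), Lit(2))))

println(eval2(Map.empty)(e1))                  // Right(IntVal(18))
println(eval2(Map.empty)(e2))                  // Left(type error in Plus/Unbound variable y)
println(eval2(Map.empty)(App(Lit(1), Lit(2)))) // Left(type error in App)
```

An unbound variable inside the argument of Plus surfaces as "type error in Plus/Unbound variable y". This happens because the Plus case prefixes whatever Left it receives.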

How about some State ?

Let's now thread some state through the function using the State monad. The State monad does not mutate anything; it passes the state along functionally from one step to the next. This means stacking yet another effect onto our pile of transformers. We would like to add some profiling capabilities that track the invocation of every pattern in the evaluator. For simplicity we just count the number of invocations as an integer and report it along with the final output. We define the new monad by wrapping a StateT constructor around the innermost monad, Identity. So now our return type becomes ..

type Eval3[A] = EitherT[String, StateTIntIdentity, A]

We layer StateT between EitherT and Identity, so we need a composition of StateT and Identity that serves as the inner monad F of EitherT. We call it StateTIntIdentity, with the state being an Int, and define it as a type lambda as follows ..

type StateTIntIdentity[α] = ({type λ[α] = StateT[Int, Identity, α]})#λ[α]

When we run the computation, the result comes back as a Tuple2 (Either[String, Value], Int) for both successful and failed evaluations, as we will see shortly.

We write a couple of helper functions that manage the state by incrementing a counter, lifting the result into a StateT monad, and finally lifting everything into EitherT.

def stfn(e: Either[String, Value]) = (s: Int) => id[(Either[String, Value], Int)](e, s+1)

def eitherNStateT(e: Either[String, Value]) =
  eitherT[String, StateTIntIdentity, Value](stateT[Int, Identity, Either[String, Value]](stfn(e)))
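Once the layers are peeled, this stack amounts to roughly a function from the incoming counter to a pair of the result (or error) and the outgoing counter. The sketch below flattens the layers by hand. Eval3 as a case class, pure and tick are hypothetical names, with tick standing in for eitherNStateT.

```scala
// Hypothetical flattened stack: EitherT[String, StateT[Int, Identity, _], A]
// behaves like Int => (Either[String, A], Int)
final case class Eval3[A](run: Int => (Either[String, A], Int)) {
  def flatMap[B](f: A => Eval3[B]): Eval3[B] = Eval3[B] { s0 =>
    val (ea, s1) = run(s0)
    ea match {
      case Left(err) => (Left(err), s1) // error: stop, but keep the counter so far
      case Right(a)  => f(a).run(s1)    // success: thread the counter onward
    }
  }

  def map[B](f: A => B): Eval3[B] = flatMap(a => Eval3.pure(f(a)))
}

object Eval3 {
  def pure[A](a: A): Eval3[A] = Eval3[A](s => (Right(a), s))
}

// counterpart of eitherNStateT: produce the given Either and add one to the counter
def tick[A](e: Either[String, A]): Eval3[A] = Eval3[A](s => (e, s + 1))

val program: Eval3[Int] = for {
  a <- tick[Int](Right(1))
  b <- tick[Int](Right(2))
} yield a + b

val failing: Eval3[Int] = for {
  a <- tick[Int](Left("boom"))
  b <- tick[Int](Right(2))
} yield a + b

println(program.run(0)) // (Right(3),2)
println(failing.run(0)) // (Left(boom),1)
```

StateT sits beneath EitherT, so an error stops the evaluation but the counter accumulated so far survives. With the layers the other way round, the state would be lost on failure.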

And here's the eval3 function that does the evaluation along with profiling and error handling ..

def eval3: Env => Exp => Eval3[Value] = {env => exp => 
  exp match {
    case Lit(i) => eitherNStateT(Right(IntVal(i)))

    case Plus(e1, e2) =>
      def appplus(v1: Value, v2: Value) = (v1, v2) match {
        case ((IntVal(i1), IntVal(i2))) => eitherNStateT(Right(IntVal(i1 + i2))) 
        case _ => eitherNStateT(Left("type error in Plus"))
      }
      for {
        i <- eval3(env)(e1)
        j <- eval3(env)(e2)
        v <- appplus(i, j)
      } yield v

    case Var(n) => 
      val v = (env get n).map(Right(_))
                         .getOrElse(Left("Unbound variable " + n))
      eitherNStateT(v)

    case Abs(n, e) => eitherNStateT(Right(FunVal(env, n, e)))

    case App(e1, e2) => 
      def appfun(v1: Value, v2: Value) = v1 match {
        case FunVal(e, n, body) => eval3(e + ((n, v2)))(body)
        case _ => eitherNStateT(Left("type error in App"))
      }

      val s =
        for {
          val1 <- eval3(env)(e1)
          val2 <- eval3(env)(e2)
          v    <- appfun(val1, val2)
        } yield v

      val ust = s.runT.value.usingT((x: Int) => x + 1)
      eitherT[String, StateTIntIdentity, Value](ust)
  }
}

We run the above function through another helper runEval3 that also takes the seed value of the state ..

def runEval3: Env => Exp => Int => (Either[String, Value], Int) = { env => exp => seed => 
  eval3(env)(exp).runT.value.run(seed)
}

Here's the REPL session with runEval3 ..

scala> val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
e1: Plus = Plus(Lit(12),App(Abs(x,Var(x)),Plus(Lit(4),Lit(2))))

scala> runEval3(env)(e1)(0)
res25: (Either[String,Value], Int) = (Right(IntVal(18)),8)

// -- failure case --
scala> val e2 = Plus(Lit(12), App(Abs("x", Var("y")), Plus(Lit(4), Lit(2))))
e2: Plus = Plus(Lit(12),App(Abs(x,Var(y)),Plus(Lit(4),Lit(2))))

scala> runEval3(env)(e2)(0)
res27: (Either[String,Value], Int) = (Left(Unbound variable y),7)
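The following library-free sketch of eval3 and runEval3, built on a flattened stack, shows where the counts 8 and 7 come from. The types, Eval3, tick and bump are hypothetical stand-ins. In particular, bump assumes that the usingT step in the App case adds one to the counter, which is consistent with the numbers above.

```scala
type Name = String

sealed trait Exp
final case class Lit(i: Int) extends Exp
final case class Var(n: Name) extends Exp
final case class Plus(e1: Exp, e2: Exp) extends Exp
final case class Abs(n: Name, e: Exp) extends Exp
final case class App(e1: Exp, e2: Exp) extends Exp

sealed trait Value
final case class IntVal(i: Int) extends Value
final case class FunVal(e: Env, n: Name, exp: Exp) extends Value

type Env = Map[Name, Value]

// Hypothetical flattened stack: Int => (Either[String, A], Int)
final case class Eval3[A](run: Int => (Either[String, A], Int)) {
  def flatMap[B](f: A => Eval3[B]): Eval3[B] = Eval3[B] { s0 =>
    val (ea, s1) = run(s0)
    ea match {
      case Left(err) => (Left(err), s1)
      case Right(a)  => f(a).run(s1)
    }
  }
  def map[B](f: A => B): Eval3[B] = flatMap(a => Eval3[B](s => (Right(f(a)), s)))
}

// counterpart of eitherNStateT: one profiling tick per call
def tick(e: Either[String, Value]): Eval3[Value] = Eval3(s => (e, s + 1))

// assumed counterpart of the usingT step: one extra tick after the wrapped computation
def bump[A](m: Eval3[A]): Eval3[A] = Eval3[A] { s0 =>
  val (ea, s1) = m.run(s0)
  (ea, s1 + 1)
}

def eval3(env: Env)(exp: Exp): Eval3[Value] = exp match {
  case Lit(i) => tick(Right(IntVal(i)))
  case Var(n) => tick(env.get(n).toRight("Unbound variable " + n))
  case Plus(e1, e2) =>
    def appplus(v1: Value, v2: Value): Eval3[Value] = (v1, v2) match {
      case (IntVal(i1), IntVal(i2)) => tick(Right(IntVal(i1 + i2)))
      case _                        => tick(Left("type error in Plus"))
    }
    for {
      i <- eval3(env)(e1)
      j <- eval3(env)(e2)
      v <- appplus(i, j)
    } yield v
  case Abs(n, e) => tick(Right(FunVal(env, n, e)))
  case App(e1, e2) =>
    def appfun(v1: Value, v2: Value): Eval3[Value] = v1 match {
      case FunVal(e, n, body) => eval3(e + (n -> v2))(body)
      case _                  => tick(Left("type error in App"))
    }
    bump(for {
      val1 <- eval3(env)(e1)
      val2 <- eval3(env)(e2)
      v    <- appfun(val1, val2)
    } yield v)
}

def runEval3(env: Env)(exp: Exp)(seed: Int): (Either[String, Value], Int) =
  eval3(env)(exp).run(seed)

val e1 = Plus(Lit(12), App(Abs("x", Var("x")), Plus(Lit(4), Lit(2))))
val e2 = Plus(Lit(12), App(Abs("x", Var("y")), Plus(Lit(4), Lit(2))))

println(runEval3(Map.empty)(e1)(0)) // (Right(IntVal(18)),8)
println(runEval3(Map.empty)(e2)(0)) // (Left(Unbound variable y),7)
```

For e1 the ticks come from Lit 12, the Abs, the three ticks of the inner Plus, the Var in the function body, the App itself and the outer Plus, which gives 8. For e2 the unbound y short-circuits the outer Plus before its own tick, which gives 7.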

In case you are interested, the whole code base is in my github repository. Feel free to check it out. I will be adding a couple more transformers: ReaderT for hiding the environment, WriterT for logging, and also IO.

6 comments:

Jason Zaugg said...

There is no need to use the type lambda to define StateTIntIdentity; they are useful to define anonymous type functions; but not needed if you're giving your type function a name.


scala> trait M[A, B, C]
defined trait M

scala> type MIntInt[A] = M[Int, Int, A]
defined type alias MIntInt

scala> type MIntInt_[A] = ({type lam[a]=M[Int, Int, a]})#lam[A]
defined type alias MIntInt_

scala> implicitly[MIntInt[String] =:= MIntInt_[String]]
res0: =:=[MIntInt[String],MIntInt_[String]] =

Jack said...

Thanks for your code..

Anonymous said...

This is a really nice post. Monad Transformers weren't a hard concept to grasp after learning monads, but when and how to use them were. Your non trivial but small example helps a lot!

Anonymous said...

Thanks for the nice write-up!

Jason Scott said...

transLift seems to have disappeared in scalaz 7.1
How would you do this now?

Phillip Henry said...

Great post and beautifully written, Debasish. Very clear.