Tuesday, March 17, 2015

Abstract Binding Trees, an addendum

It struck me after the last post that it might be helpful to give an example using abstract binding trees in a more nontrivial way. The pure lambda calculus has a very simple binding structure, and pretty much anything you do will work out. So I decided to show how ABTs can be used to easily support a much more involved form of binding -- namely, pattern matching.

This makes for a nice example, because it is a very modular and well-structured language feature that nonetheless has a rather complex binding structure. However, as we will see, it fits very easily into an ABT-based API. I will assume the implementation from the previous post, and focus only on the use of ABTs.

As is usual, we will need to introduce a signature module for the language.

 module Lambda =
 struct
Since pattern matching really only makes sense with a richer set of types, let's start by adding sums and products to the type language.
   type tp = Arrow of tp * tp | One | Prod of tp * tp | Sum of tp * tp 

Next, we'll give a datatype of patterns. The PWild constructor is the wildcard pattern $\_$, the PVar constructor is a variable pattern, the PUnit constructor is the unit pattern, the PPair(p, p') constructor is the pair pattern, and the PInl p and PInr p constructors are the patterns for the left and right injections of the sum type.

   type pat = PWild | PVar | PUnit | PPair of pat * pat 
            | PInl of pat | PInr of pat 

Now, we can give the datatype for the language signature itself. We add pairs, units, and sums to the language, as well as a case expression which takes a scrutinee, and a list of branches (a list of patterns and the corresponding case arms).

   type 'a t = Lam of 'a | App of 'a * 'a | Annot of tp * 'a
             | Unit | Pair of 'a * 'a | Inl of 'a | Inr of 'a
             | Case of 'a * (pat * 'a) list

One thing worth noting is that the datatype of patterns has no binding information in it at all. The basic idea is that we will represent (say) a pair elimination $\mathsf{let}\;(x,y) = e\;\mathsf{in}\;e'$ as the constructor Case(e, [PPair(PVar, PVar), Abs(x, Abs(y, e'))]) (suppressing useless Tm constructors). So the PVar constructor is merely an indication that the arm must be an abstraction, with the number of abstractions determined by the shape of the pattern. This representation was first documented in Rob Simmons's paper Structural Focalization.

This is really the key thing that gives the ABT interface its power: binding trees have only one binding form, and we never introduce any others.
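To make "the number of abstractions is determined by the shape of the pattern" concrete, here is a minimal standalone sketch. The helper name binders is my own and is not part of the implementation in this post; it simply counts the PVar leaves of a pattern, which is the number of nested Abs nodes the corresponding arm is expected to begin with.

```ocaml
(* Standalone copy of the pattern type from the signature. *)
type pat = PWild | PVar | PUnit | PPair of pat * pat
         | PInl of pat | PInr of pat

(* Hypothetical helper (not in the post): the number of variables a
   pattern binds, i.e. how many nested abstractions its arm must have. *)
let rec binders = function
  | PWild | PUnit -> 0
  | PVar -> 1
  | PPair (p, p') -> binders p + binders p'
  | PInl p | PInr p -> binders p

(* The pair pattern (x, y) needs an arm of the form Abs(x, Abs(y, e')). *)
let () = assert (binders (PPair (PVar, PVar)) = 2)
```

Wildcards and unit patterns contribute nothing, which matches the way the typechecker below skips over them without unabstracting the arm.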

Now we can give the map and join operations for this signature.

 
   let map f = function
     | Lam x        -> Lam (f x)
     | App (x, y)   -> App(f x, f y)
     | Annot(t, x)  -> Annot(t, f x)
     | Unit         -> Unit 
     | Pair(x, y)   -> Pair(f x, f y)
     | Inl x        -> Inl (f x)
     | Inr x        -> Inr (f x)
     | Case(x, pys) -> Case(f x, List.map (fun (p, y) -> (p, f y)) pys)

   let join m = function
     | Lam x -> x
     | App(x, y) -> m.join x y
     | Annot(_, x) -> x
     | Unit        -> m.unit
     | Pair(x,y)   -> m.join x y
     | Inl x       -> x
     | Inr x       -> x
     | Case(x, pys) -> List.fold_right (fun (_, y) -> m.join y) pys x 
 end
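As a rough illustration of what join does on a Case node, here is a standalone sketch over a cut-down signature. I am assuming that the monoid argument is a record with fields unit and join, as the uses m.unit and m.join above suggest; the record declaration, the trimmed signature, and the names monoid value are my own for the purposes of the example.

```ocaml
(* Assumed shape of the monoid record that join takes. *)
type 'a monoid = { unit : 'a; join : 'a -> 'a -> 'a }

(* A cut-down version of the signature, just enough to show Case. *)
type pat = PWild | PVar
type 'a t = Unit | Pair of 'a * 'a | Case of 'a * (pat * 'a) list

let join m = function
  | Unit -> m.unit
  | Pair (x, y) -> m.join x y
  | Case (x, pys) -> List.fold_right (fun (_, y) -> m.join y) pys x

(* Lists of names under append, standing in for sets of free variables. *)
let names = { unit = []; join = (@) }

(* The scrutinee's result is the seed of the fold, so it ends up last. *)
let fvs = join names (Case (["e"], [(PVar, ["a"]); (PWild, ["b"; "c"])]))
let () = assert (fvs = ["a"; "b"; "c"; "e"])
```

The patterns themselves are ignored by join, which is exactly what we want: they carry no subterms and no binding information.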

As usual, we construct the syntax by applying the Abt functor.

 
 module Syntax = Abt(Lambda)

We can now define a bidirectional typechecker for this language. Much of the infrastructure is the same as in the previous post.

 
 module Bidir = struct
   open Lambda
   open Syntax
  type ctx = (var * tp) list

We do, however, extend the is_check and is_synth functions to handle the new operations in the signature. Note that case statements are viewed as checking forms.

  let is_check = function 
    | Tm (Lam _) | Tm Unit | Tm (Pair(_, _))
    | Tm (Inl _) | Tm (Inr _) | Tm (Case(_, _))-> true
    | _ -> false
  let is_synth e = not(is_check e)

  let unabs e =
    match out e with
    | Abs(x, e) -> (x, e)
    | _ -> assert false

When we reach the typechecker itself, most of it (the check and synth functions) is unchanged. We have to add new cases to check to handle the new value forms (injections, pairs, and units), but they are pretty straightforward.

  let rec check ctx e tp =
    match out e, tp with
    | Tm (Lam t), Arrow(tp1, tp') ->
      let (x, e') = unabs t in
      check ((x, tp1) :: ctx) e' tp'
    | Tm (Lam _), _               -> failwith "expected arrow type"
    | Tm Unit, One  -> ()
    | Tm Unit, _ -> failwith "expected unit type"
    | Tm (Pair(e, e')), Prod(tp, tp') ->
      let () = check ctx e tp in
      let () = check ctx e' tp' in
      ()
    | Tm (Pair(_, _)), _ -> failwith "expected product type"
    | Tm (Inl e), Sum(tp, _) -> check ctx e tp
    | Tm (Inr e), Sum(_, tp) -> check ctx e tp
    | Tm (Inl _), _
    | Tm (Inr _), _          -> failwith "expected sum type"

The big difference is in checking the Case form. Now, we need to synthesize a type for the scrutinee, and then check that each branch is well-typed. This works by calling the check_branch function on each branch, passing it the pattern, arm, type and result type as arguments. For reasons that will become apparent, the pattern and its type are passed in as singleton lists. (I don't do coverage checking here, only because it doesn't interact with binding in any way.)

    | Tm (Case(e, branches)), tp ->
      let tp' = synth ctx e in
      List.iter (fun (p,e) -> check_branch ctx [p] e [tp'] tp) branches 
    | body, _ when is_synth body ->
      if tp = synth ctx e then () else failwith "Type mismatch"
    | _ -> assert false

  and synth ctx e =
    match out e with
    | Var x -> (try List.assoc x ctx with Not_found -> failwith "unbound variable")
    | Tm(Annot(tp, e)) -> let () = check ctx e tp in tp 
    | Tm(App(f, e))  ->
      (match synth ctx f with
       | Arrow(tp, tp') -> let () = check ctx e tp in tp'
       | _ -> failwith "Applying a non-function!")
    | body when is_check body -> failwith "Cannot synthesize type for checking term"
    | _ -> assert false

The way that branch checking works is by steadily deconstructing a list of patterns (and their types) into smaller lists.

  and check_branch ctx ps e tps tp_result =
    match ps, tps with

If there are no more patterns and their types, we are done, and can check that the arm has the right type.

    | [], []-> check ctx e tp_result

If we have a variable pattern, we unabstract the arm, bind the resulting variable to the type at the head of the list, and recur on the smaller lists of patterns and types. Note that this is the only place we have to do anything at all with binding, and it's trivial!

    | (PVar :: ps), (tp :: tps)
      -> let (x, e) = unabs e in
         check_branch ((x, tp) :: ctx) ps e tps tp_result

Wildcards and unit patterns work the same way, except that they don't bind anything.

    | (PWild :: ps), (tp :: tps)
      -> check_branch ctx ps e tps tp_result
    | (PUnit :: ps), (One :: tps)
      -> check_branch ctx ps e tps tp_result
    | (PUnit :: ps), (tp :: tps)
      -> failwith "expected term of unit type"

Pair patterns are deconstructed into their two sub-patterns, the product type they are checked against is broken into its two component types, and then the two lists are lengthened with the sub-patterns and their types.

    | (PPair(p, p') :: ps), (Prod(tp, tp') :: tps) 
      -> check_branch ctx (p :: p' :: ps) e (tp :: tp' :: tps) tp_result
    | (PPair(p, p') :: ps), (_ :: tps) 
      -> failwith "expected term of product type"

Sum patterns work by recurring on the sub-pattern, dropping the left or right part of the sum type, as appropriate.

    | (PInl p :: ps), (Sum(tp, _) :: tps)
      -> check_branch ctx (p :: ps) e (tp :: tps) tp_result
    | (PInl p :: ps), (_ :: tps)
      -> failwith "expected sum type"
    | (PInr p :: ps), (Sum(_, tp) :: tps)
      -> check_branch ctx (p :: ps) e (tp :: tps) tp_result
    | (PInr p :: ps), (_ :: tps)
      -> failwith "expected sum type"
    | _ -> assert false
end
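To see the worklist traversal in isolation, here is a small standalone sketch that follows the same case analysis as check_branch but, instead of extending a context and checking an arm, just returns the types of the bound variables in binding order. The function bound_types is hypothetical and not part of the typechecker above; it is only meant to show how the two lists shrink and grow in lockstep.

```ocaml
type tp = Arrow of tp * tp | One | Prod of tp * tp | Sum of tp * tp
type pat = PWild | PVar | PUnit | PPair of pat * pat
         | PInl of pat | PInr of pat

(* Hypothetical helper mirroring check_branch: the types of the variables
   bound by the patterns, outermost binder first. *)
let rec bound_types ps tps =
  match ps, tps with
  | [], [] -> []
  | PVar :: ps, tp :: tps -> tp :: bound_types ps tps
  | PWild :: ps, _ :: tps -> bound_types ps tps
  | PUnit :: ps, One :: tps -> bound_types ps tps
  | PPair (p, p') :: ps, Prod (tp, tp') :: tps ->
    bound_types (p :: p' :: ps) (tp :: tp' :: tps)
  | PInl p :: ps, Sum (tp, _) :: tps -> bound_types (p :: ps) (tp :: tps)
  | PInr p :: ps, Sum (_, tp) :: tps -> bound_types (p :: ps) (tp :: tps)
  | _ -> failwith "pattern does not match type"

(* The pattern (x, inl f) against 1 * ((1 -> 1) + 1) binds x : 1 and
   f : 1 -> 1, in that order. *)
let () =
  let p = PPair (PVar, PInl PVar) in
  let t = Prod (One, Sum (Arrow (One, One), One)) in
  assert (bound_types [p] [t] = [One; Arrow (One, One)])
```

The order of the result is the order in which check_branch calls unabs, so it also tells you the order in which the arm's abstractions must be nested.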

That's it! So I hope it's clear that ABTs can handle complex binding forms very gracefully.
