A few months ago, I began work on a parser generator for fun. This work is being done in Python (my daily language for many reasons), but occasionally I find myself at odds with the lack of certain functionality in the language. In this particular instance, I desperately wanted a pattern-match syntax with “compile-time” exhaustiveness guarantees. Instead of even looking to find an existing solution, I launched right into developing it. I recently re-opened the project and found myself re-reading the code to implement this, and figured I’d write down my solution here in detail.

Pattern Matching

In many (most? all?) functional programming languages, there exists a capability called pattern matching. The basic notion is that you can take an input value which may have one of a number of forms, and you can essentially build a switch over the possibilities.

But more than that, pattern matching allows you to destructure those values, giving names to the internal parts and then using them in the corresponding block of code.

For example, consider the following algebraic datatype in Haskell:

1
2
3
4
data BTree =
    | Branch Int BTree BTree
    | Leaf Int
    | Empty

This type represents binary trees containing integers, where each branch or leaf holds a value.

We could write a function to sum all the integers in such a tree using recursion and pattern matching:

1
2
3
4
5
sumTree :: BTree -> Int
sumTree t = case t of
    | Branch val left right -> val + (sumTree left) + (sumTree right)
    | Leaf val              -> val
    | Empty                 -> 0

This is an example of a pattern match using destructuring. The match syntax (the case ___ of bit followed by the | ___ -> ___ clauses) allows for easily handling the cases of an algebraic data type, and you instantly gain the ability to name the fields for use in the case clause.

Better still is that Haskell (and other languages that support pattern matching) can perform exhaustiveness checks at compile-time. What this means is that if I had mistakenly left out the Empty case in the above function definition, the compiler would have complained about it because an Empty value could be passed in and the not handled! This guarantee is incredibly useful, because if you ever alter the original datatype definition (as happens during development), you are instantly given the ability to find out where you need to make adjustments in the rest of your code for free.

Python’s Problem

The problem for Python is that it simply doesn’t support this. There is no functionality in the language for pattern matching.

But I wanted it. See, I was transcribing some code which had been written in Racket using pattern matches, and in this case I think the pattern-match syntax really is easier to read than using object-oriented programming. So I set out to devise my own pattern-match syntax for Python.

I had some criteria, though:

  • It must be shorter than writing out manual checks.
  • It must be (relatively) easy to read.
  • It must perform exhaustiveness checking at “compile-time” (more on this later).
  • It must support destructuring, or something very like it.

The Solution

It took me a while, but I think I’ve arrived at a solution that checks all the boxes (though I don’t claim it’s the most beautiful thing to use). The code may update over time (I hope), so be sure to check out the current version on GitHub here.

Use is relatively straightforward! Here’s an example:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from .match import match

from dataclasses import dataclass

# First, define the classes we'll match over.
@dataclass
class BTree:
  pass

@dataclass
class Branch(BTree):
  val: int
  left: BTree
  right: BTree

@dataclass
class Leaf(BTree):
  val: int

@dataclass
class Empty(BTree):
  pass

# Then build a match function.
sum_tree = match({
  Branch: lambda _, val, left, right: val + sum_tree(left) + sum_tree(right),
  Leaf:   lambda _, val:              val,
  Empty:  lambda _:                   0,
}, BTree)

# Lastly, a test!
assert sum_tree(Branch(2, (Branch (1, Leaf(4), Empty())), Leaf(7))) == 14

Here’s the current version, with comments and docstring removed and a few areas marked with numbers so we can talk about them more easily:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# (1)
def match(table: Dict[Type, Callable[..., Val]],
          base: Optional[Type] = None,
          params: Optional[Tuple[str, ...]] = None,
          pos: int = 0,
          exhaustive: bool = True,
          omit: Optional[Set[Type]] = None,
          omit_recursive: bool = False,
          same_module_only: bool = True,
          destructure: bool = True
         ) -> Callable[..., Val]:
  # (2)
  _caller: Traceback = getframeinfo(stack()[1][0])
  _mdfn = _caller.filename  # MDFN = Match Definition File Name.
  _mdln = _caller.lineno  # MDLN = Match Definition Line Number.

  if params is None:
    params = ('_match_obj',)
  if omit is None:
    omit = set()
  if base is None and exhaustive:
    raise MatchDefinitionError(_mdfn, _mdln, "Cannot perform exhaustive match without a given base class.")

  funcs: Dict[Type,
              Tuple[Callable[..., Val],
                    Dict[str,
                         Callable[[List[Any], Any], Any]]]] = {}
  subclasses: Dict[Type, bool] = {}

  # (3)
  def get_all_subclasses(cls: Type):
    for subclass in cls.__subclasses__():
      if subclass not in omit:
        subclasses[subclass] = False
      if not omit_recursive:
        get_all_subclasses(subclass)

  if base is not None:
    get_all_subclasses(base)

  #(4)
  for t, f in table.items():
    if base is not None:
      if t not in subclasses:
        raise InvalidClausePatternError(_mdfn, _mdln, base, t)
      subclasses[t] = True
    func_params = list(params)
    if destructure:
      annotations: Dict[str, Any] = t.__dict__.get('__annotations__', {})
      func_params.extend(annotations.keys())
    else:
      annotations = {}
    duped_names = {name for name in params if name in annotations}
    if duped_names:
      raise DuplicatedNamesError(_mdfn, _mdln, list(duped_names))
    sig_params = signature(f).parameters.keys()
    if len(sig_params) != len(func_params):
      extra_params   = [param for param in sig_params
                        if param not in func_params]
      missing_params = [param for param in func_params
                        if param not in sig_params]
      min_len = min(len(extra_params), len(missing_params))
      extra_params = extra_params[min_len:]
      missing_params = missing_params[min_len:]
      _, src_ln = findsource(f)
      src_ln += 1  # Line numbers from `findsource` are 0-indexed.
      raise ClauseSignatureError(_mdfn, _mdln, t, src_ln, extra_params, missing_params)
    getters: Dict[str, Callable[[List[Any], Any], Any]] = {}
    for i, name in enumerate(sig_params):
      if i < len(params):
        getters[name] = ParamGetFunc(i)
      else:
        getters[name] = AttrGetFunc(name)
    funcs[t] = (f, getters)

  # (5)
  if exhaustive:
    missing_subclasses: Set[Type] = {cls for cls in subclasses
                                     if not subclasses[cls]}
    if same_module_only:
      missing_subclasses = {cls for cls in missing_subclasses
                            if cls.__module__ == base.__module__}
    if missing_subclasses:
      raise NonExhaustiveMatchError(_mdfn, _mdln, missing_subclasses)

  # (6)
  def do_match(*args: Any) -> Val:
    x = args[pos]
    cls = x.__class__
    fgs = funcs.get(cls)
    if fgs is None:
      raise NoMatchError(_mdfn, _mdln, cls)
    func, gs = fgs
    call_params: Dict[str, Any] = {}
    for name in gs:
      try:
        val = gs[name](args, x)
      except Exception:
        raise MatchError(_mdfn, _mdln, f"Could not obtain value via getter for param {name}.")
      call_params[name] = val
    return func(**call_params)

  # (7)
  return do_match

Let’s step through it one piece at a time! I’ll try to keep it relatively brief, though.

1. Function signature

The function is intended to be used in a very particular way. In most cases, you should only pass in two arguments (table and base). Occasionally, params will be useful. Most of the other options should probably be left alone most of the time.

The match function is used at the top-level to define another function. That is, you assign a value to the result of calling match(...), and the value produced is a function which will perform the pattern matching. For an example, just look up a bit!

2. Setup

The function starts by performing some setup.

First, the current frame is obtained. This allows for determining the filename and line number where the match is declared, which helps with error messages.

Then, some of the parameters are checked for None values. params represents the list of parameters to the base function; these will be passed along to each clause’s lambda in order. omit determines whether there are any specific classes that should be omitted from exhaustiveness checking, which can be useful if there are abstract classes in the middle of the class hierarchy. And base is the base class that you’re performing the match over, which is needed if the match function is performing exhaustiveness checks (which it does by default).

Lastly, the funcs and subclasses dictionaries are initialized. funcs will contain the functions used by the returned function at the end, and subclasses is used for checking exhaustiveness.

3. get_all_subclasses

The get_all_subclasses method is then defined and called (if the base value is not None). This function looks through the subclasses of the base class and builds the list of classes we expect to see in the table. By default, all subclasses are included, which means this will recursively inspect the classes until their hierarchy is fully walked.

4. Pre-processing

Next, match iterates over the passed-in table, performing a number of checks and preparing the final dispatch table along the way. Let’s take a look at each important piece. (In the iteration, t is short for type and f for function.)

The first check is simple: if a base value was supplied, we validate that each class in the table is a recognized subclass of that base class. The subclasses dictionary is updated to reflect that we saw this subclass in the table.

Next, the annotations of the current class are read. The function corresponding to each clause must provide sufficient arguments to handle both the params as well as all the fields of the subclass. The annotations are used for that second part. (Note that this requires that the subclasses are all implemented using @dataclass. Perhaps I’ll relax that requirement at some point.)

Then we check to ensure that there’s no overlap between the parameter names given in params and the fields of the subclass. This is because the object fields will be retrieved by name based on the signature of the function, and the function must include fields for all the parameters in params. So we can’t have any overlap in the names.

The signature of the clause function is then inspected. This gives us a list of expected parameter names. The length of this list is then compared to the length of the list of expected parameters. If there’s a mismatch (indicating that either the function has too few or too many fields), an error is raised that labels which clause is to blame (with line number) and which fields are extra or missing.

Lastly, getter functions are built. These functions are used to map the match function parameters and passed-in object fields to the parameters of the clause function efficiently. The getter functions work either over parameters to the whole match function (ParamGetFunc) or else the fields of the passed-in matching object (AttrGetFunc).

5. Checking for exhaustiveness

At this point, the exhaustiveness check is pretty straightforward: the subclasses dictionary will have False entries for subclasses which do not have clauses in the table.

By default, match only checks for exhaustiveness within a single module. This means that if you declare a base class in one file and all the implementations in separate files, you’ll need to set same_module_only to False. Otherwise, this portion of the code will not trigger an exception.

6. Building the real match function

Of course, we want (some) performance! Instead of performing these exhaustiveness checks at run-time, they are checked at “compile-time” (i.e., when the defining module is first loaded). Once we’ve checked exhaustiveness statically in this manner, we can skip them in the future.

To that end, the match function produces a new function to perform matching and execute the clauses at run-time. Section (6) builds that function.

First, the matching object (x) is identified by position. (By default, it is assumed to be the first argument.) The class is extracted and used to look up the function-getters we created back in the end of (4). Of course, if there’s no match then an error is raised.

The object is then destructured using the getters that were defined. ParamGetFunc parameters are obtained via indexing into the args, and AttrGetFunc parameters by using getattr on the matching object. If any parameter can’t be found, an error is raised. These parameters are all stored in a dictionary with their name from the clause function.

Finally, the clause function is called with all of the necessary parameters.

7. Returning the function

The last thing done by match is returning the real match function we just created in (6).

The result is a comparatively quick pattern match function!

In Action

Although I wrote it above, let’s look again at the way this is used:

1
2
3
4
5
sum_tree = match({
  Branch: lambda _, val, left, right: val + sum_tree(left) + sum_tree(right),
  Leaf:   lambda _, val:              val,
  Empty:  lambda _:                   0,
}, BTree)

match is given a table which maps the subclasses to functions. These functions must take (a) the parameters of the function (usually just the matching object) and (b) the fields of the object being matched. Since this example has no need of looking at the passed-in object in any case, the first field of each lambda in this example is _ (to indicate that it can be thrown away).

The function defined in each lambda computes a result. I have found that if this gets too complex, it can be worth factoring out another function and then simply calling that inside the lambda. It’s not perfect, but it works!

Note that the parameters of the lambda must align with the attributes of the matched object. I’m considering changing this so it’s done by position instead of by name, but I first have to determine whether the order of defined attributes in a dataclass instance can be relied-upon, and then I have to decide if it’s worth requiring all subclasses that are matched-on to be instances of dataclass.

At the end, match is passed in BTree. This is the base class that the exhaustiveness checks will be performed relative to. This can be left out, but then there can’t be any exhaustiveness guarantees. I may also change this; all of the subclasses in the tree can be used to walk a class hierarchy and find a least common ancestor in that hierarchy. But this will be left to the future for now.

Room for Improvement

Of course, my solution isn’t perfect.

There are a lot of assumptions in this code about how Python works, from inspecting the call stack to hoping no subclasses are co-dependent somehow. I think there’s plenty of ways I can improve this function over time, but for the moment I’m happy with it.

I also have a predicated match function (defined in the same module as match_pred). This function uses a more complicated table, but allows for the construction of guarded clauses — similar to the use of where in Haskell. However, this is currently implemented very simply and is not too robust, unlike the more fully-featured match I just explained. I intend to update pred_match to work more like match in the future!