Implementing Do Notation in Python
a.k.a. Evil hacks in python for fun and profit
disclaimer: This is all a really bad idea, but it was pretty fun to do. Hopefully I can teach people how to do equally horrible-yet-fun things. You should not actually use the decorator I built.
If you just want to see it in action, go to https://github.com/imh/python_do_notation.
Monads are a really useful as an abstraction, and a nice clean syntax goes a long way in making them pain-free to use.
In Haskell, that’s do
notation.
Scala has a similar for
notation.
I spend most of my time in python, though, and it doesn’t have anything analogous.
Can we make a pain-free equivalent in python?
We want to find a way to make something that looks nice act like a sequence of lambdas chained to each other with bind statements. If you squint really hard and turn your head, taking one piece of code and making it do something else is the problem that decorators solve, and as it turns out, you can make a decorator to do pretty much anything in python. I’ll start with a decorator overview for those who haven’t used them. Skip this section if you’re comfortable building these already.
Python decorators
A python decorator is a function that takes a function and returns a function. I’ll use the example of a function that logs the input and return value of another function.
def log_input_and_output(f):
def sneaky_replacement_function(x):
print 'INPUT: ', x
y = f(x)
print 'OUTPUT: ', y
return y
return sneaky_replacement_function
This function takes a univariate function f
and returns another function that just calls f
, but does other stuff besides (printing its input and output).
We could use it to log some math like so:
def double_x(x):
return x*2
double_and_log_x = log_input_and_output(double_x)
double_and_log_x(5) # returns 10
# >> INPUT: 5
# >> OUTPUT: 10
We took a function double_x
and passed it to log_input_and_output
, thereby getting a new function that does all the stuff double_x
does, and then some.
A lot of times, we aren’t actually interested in the function we’re defining on its own. We only ever want the logged version. In this case we could overwrite it like so:
def double_x(x):
return x*2
double_x = log_input_and_output(double_x)
Now we only have the version we want. Python actually has a special way to do that to make it more convenient:
@log_input_and_output
def double_x(x):
return x*2
For almost all purposes, it’s equivalent to the code block above, but it looks much nicer and you can chain them together easily.
Functions of functions
So there you have it: decorators are a nice way of modifying a function you wrote and turning it into a function you didn’t. That’s just what we’re trying to do to bastardize a new monad syntax! It sounds close enough, so let’s smash this round peg into the square hole. Let’s try to build a decorator that allows something along the lines of:
@with_do_notation
def my_function(x):
y = do:
a <- some_monadic_function(x)
b <- more_stuff(a)
mreturn(a + b)
return y
With any python decorator, we get a function and return a function, but all we’re supposed to do is call it in clever ways. We can call it with different arguments, or do evil things with the output, but in the end, we still just have this black box inside. What we really need to do to add a new syntax is to break open the black box and muck around with the internals, turning it into valid python before (hopefully) packing it back up again. Round pegs don’t fit into square holes, so we need to whip out our hacksaw.
We have to stick with valid python
No matter what the decorator with_do_notation
is, if the function we’re decorating isn’t syntactically valid python and it’s just going to error.
I don’t want to write a new language that compiles to python, so I have to obey python’s syntax rules.
That means we’re searching for valid python that looks like a block of code with some semblance of associating a name to it.
It also can’t be something someone’s going to accidentally use, since we’re overriding its intended meaning.
For binding, I decided to replace haskell’s <-
with plain old equals =
.
For the block itself, with x as y:
came to mind:
@with_do_notation
def my_function(x):
with do as y:
a = some_monadic_function(x)
b = more_stuff(a)
mreturn(a + b)
return y
Finally, there’s no type inference, so I had to hint out what the type of mreturn
should be:
@with_do_notation
def my_function(x):
with do(Maybe) as y:
a = some_monadic_function(x)
b = more_stuff(a)
mreturn(a + b)
return y
It’s not the prettiest, but it doesn’t throw an exception. It’s clean enough that I’m happy.
The ast library is a lovely little hacksaw
Back to our earlier problem, we have to crack open the black box that is my_function
and rewrite the with
block inside into a sequence of binds.
For that, I need a way to inspect and modify existing python functions.
Anyone who has abused __dict__
knows that python is very friendly to this kind of thing.
It’s actually scary how friendly this language is to doing things that feel dirty.
Do it too much and you’ll start writing blog posts like this one.
For this particular use case, I found the ast
library.
It gives you tools to work with python code, traversing and modifying the AST as needed.
The official docs aren’t that great for jumping in, but there’s another set of docs called Green Tree Snakes that helped me get started.
Parsing
With the inspect
and ast
libraries, we can reconstruct the AST for our my_function
function:
# inspect.getsource just returns the source code for the function
# textwrap.dedent unindents it, so we can parse it alone
src = dedent(inspect.getsource(my_function))
# Use the ast library to parse the function definition
module = ast.parse(src)
# pretty print the module with this: https://bitbucket.org/takluyver/greentreesnakes/src/default/astpp.py?fileviewer=file-view-default
# "print dump(module)" returns this:
Module(body=[
FunctionDef(name='my_function', args=arguments(args=[
Name(id='x', ctx=Param()),
], vararg=None, kwarg=None, defaults=[]), body=[
With(context_expr=Call(func=Name(id='do', ctx=Load()), args=[
Name(id='Maybe', ctx=Load()),
], keywords=[], starargs=None, kwargs=None), optional_vars=Name(id='y', ctx=Store()), body=[
Assign(targets=[
Name(id='a', ctx=Store()),
], value=Call(func=Name(id='some_monadic_function', ctx=Load()), args=[
Name(id='x', ctx=Load()),
], keywords=[], starargs=None, kwargs=None)),
Assign(targets=[
Name(id='b', ctx=Store()),
], value=Call(func=Name(id='more_stuff', ctx=Load()), args=[
Name(id='a', ctx=Load()),
], keywords=[], starargs=None, kwargs=None)),
Expr(value=Call(func=Name(id='mreturn', ctx=Load()), args=[
BinOp(left=Name(id='a', ctx=Load()), op=Add(), right=Name(id='b', ctx=Load())),
], keywords=[], starargs=None, kwargs=None)),
]),
Return(value=Name(id='y', ctx=Load())),
], decorator_list=[]),
])
That’s the AST for my_function
.
Its bulky, but you can recognize its original definition from above in it.
We can work with that.
Rewriting the AST
The goal here is to find any With
nodes and rewrite their contents.
ast
makes this super easy with ast.NodeTransformer
.
If you inherit from NodeTransformer
, you can override the visit_Foo
method to do something special to the Foo
nodes.
In this case, we want to override the With
block:
class RewriteWithDo(NodeTransformer):
def visit_With(self, node):
self.generic_visit(node)
# Make sure its context expression is a function called "do"
if not (hasattr(node.context_expr, 'func') and
node.context_expr.func.id == 'do'):
return node
name = node.optional_vars.id
# The argument of the "do" function is the name of the monad class.
monad = node.context_expr.args[0].id
bind_chain = rewrite_with_to_binds(node.body, monad)
# Assign the result of the bind chain to the name in
# with do(...) as name:
return Assign(targets=[Name(id=name, ctx=Store())],
value=bind_chain)
We check to make sure we’re only rewriting nodes for the With
blocks of the form with do(MyClass) as my_name:
, then rewrite the body into a sequence of monadic binds:
class RewriteDoBody(NodeTransformer):
def __init__(self, monad):
self.monad = monad
super(RewriteDoBody, self).__init__()
def visit_Call(self, node):
self.generic_visit(node)
if not (isinstance(node.func, Name) and
node.func.id == 'mreturn'):
return node
node.func = Attribute(value=Name(id=self.monad, ctx=Load()), attr='mreturn', ctx=Load())
return node
# TODO allow let bindings in do block
def rewrite_with_to_binds(body, monad):
new_body = []
# Construct a transformer for this specific monad's mreturn
rdb = RewriteDoBody(monad)
# This is the body of the lambda we're about to construct
last_part = body[-1].value
# Rewrite mreturn
rdb.visit(last_part)
# Iterate in reverse, making each line the into a lambda whose body is the
# rest of the lines (which are each lambdas), and whose names are the
# bind assignments.
for b in reversed(body[:-1]):
rdb.visit(b)
if isinstance(b, Assign):
name = b.targets[0].id
value = b.value
else :
# If there was no assignment to the bind, just use a random name, eek
name = '__DO_NOT_NAME_A_VARIABLE_THIS_STRING__'
value = b.value
# last part = value.bind(lambda name: last_part)
last_part = Call(func=Attribute(value=value, attr='bind', ctx=Load()),
args=[Lambda(args=arguments(args=[Name(id=name, ctx=Param()),],
vararg=None, kwarg=None, defaults=[]),
body=last_part),],
keywords=[], starargs=None, kwargs=None)
return last_part
With these node transformers, we can finally write our decorator:
def with_do_notation(f):
# Get the context for the function we're calling this from
frame = inspect.stack()[1][0]
# Get the function's source
src = dedent(inspect.getsource(f))
# Parse it into an AST
module = parse(src)
function_def = module.body[0]
function_name = function_def.name
assert(isinstance(function_def, FunctionDef))
# Rewrite any `with do(MyMonadInstance) as my_name:` blocks
RewriteWithDo().visit(module)
# Remove the with_do_notation decorator, so it doesn' recurse infinitely
# when we compile it
function_def.decorator_list = [d for d in function_def.decorator_list
if not (isinstance(d, Name) and d.id=='with_do_notation')]
# Define the function in the scope it was originally defined, with its
# original name and new body
exec(compile(fix_missing_locations(module),
filename='<ast>', mode='exec'), frame.f_globals, frame.f_locals)
# Return the new function
return eval(function_name, frame.f_globals, frame.f_locals)
There it is! A neat decorator to implement totally new and abusive functionality in about 50 lines (plus comments). You can find the whole thing here on github.
Examples
I also wrote up some examples in that repo to see whether it works for decently and is actually painless to use.
I implemented the maybe monad, so you can write code like this:
just = lambda x: Maybe(just=x)
nothing = Maybe()
@with_do_notation
def decrement_positives(x):
with do(Maybe) as y:
a = just(x) if x > 0 else nothing
just(a-1)
return y
print decrement_positives(0) # Nothing
print decrement_positives(1) # Just 0
print decrement_positives(2) # Just 1
I implemented the list monad, so you can write code like this:
@with_do_notation
def list_example():
list1 = ListMonad([1,2,3])
list2 = ListMonad([-1,2])
with do(ListMonad) as z:
x = list1
y = list2
mreturn(x*y)
assert(sorted(z.lst) == sorted([x*y for x in list1.lst for y in list2.lst]))
return z
print list_example() # ListMonad([-1, 2, -2, 4, -3, 6])
And finally, for a richer test, I implemented a monadic parser. It was tremendously satisfying and was a total breeze to write with this notation. I got to write code like this:
@with_do_notation
def parse_string(string):
if string == "":
return Parser.mreturn("")
if len(string) > 0:
with do(Parser) as x:
x = char(string[0])
parse_string(string[1:])
mreturn(string)
return x
return parser_zero
@with_do_notation
def sepby1(p, sep):
with do(Parser) as throwaway_a_sep:
sep
p
with do(Parser) as p2:
x = p
xs = many(throwaway_a_sep)
mreturn(x + xs)
return p2
Of course, this kind of recursive programming will exceed your max recursion depth pretty quickly, but it was great to see that, at least in principle, it’s not too hard to make it easy.