Mathematical Expressions as Trees 1: Modeling, Evaluation and Printing

I recently watched a Computerphile video in which Prof. Thorsten Altenkirch talks about trees as a data structure and uses them to model mathematical expressions. At the end, it is briefly mentioned that this representation can also be used to compute derivatives of expressions, which inspired me to try that out myself.

So I opened a Jupyter notebook and started with an implementation analogous to the one in the video, but somewhat extended. You can find a notebook with the state of the code at the end of this blog post here.

The goal

Let me show you the general idea first:

I want to be able to define a mathematical expression by nesting different fundamental expressions like this:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))

With this expression, it should then be possible to print it as a string and to evaluate it for specific values of the variables used (in this case x and y):

str(expr) # == '3*x^2+y-6/3'
expr.eval({'x':4, 'y':7}) # == 53

So far, this is (almost) identical to what was done in the video. In addition, I added methods to compute partial derivatives and to simplify an expression (both will be shown in part two):

str(expr.partial('x')) # == '0*x^2+3*1*2*x^(2-1)+0-(0*3-6*0)/3^2'
str(expr.partial('x').simplify()) # == '3*2*x'

I have some further ideas for what could be added, but this is the status so far.

Base Classes

Ok, so I need classes for all the different fundamental expressions that I want to use. Since they all need the same interface, I derive them from a common base class:

class Expr():
    def __init__(self):
        raise NotImplementedError()

    def __str__(self):
        raise NotImplementedError()

    def eval(self, env):
        # take a dict that assigns values to variable names and evaluate the
        # expression with them
        raise NotImplementedError()

    # The following will only be shown in part two
    def partial(self, var):
        # return the partial derivative of this expression with respect to
        # the variable var (given by a string)
        raise NotImplementedError()

    def constValue(self):
        # if this expression depends on variables, return None;
        # otherwise return the value of this expression (which is constant)
        raise NotImplementedError()

    def simplify(self):
        # reduce the complexity of this expression and return the new expression
        raise NotImplementedError()
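Since Expr only defines an interface, the standard library's abc module can enforce it when an object is created instead of when a method is called. This is an alternative to the raise-based approach used in this post, shown as a minimal sketch; Incomplete and can_instantiate are hypothetical helpers that exist only for the demonstration:

```python
from abc import ABC, abstractmethod


class Expr(ABC):
    """Common interface; ABC refuses to instantiate classes with abstract methods."""

    @abstractmethod
    def __str__(self):
        ...

    @abstractmethod
    def eval(self, env):
        """Evaluate the expression with variable values from the dict env."""
        ...


class Con(Expr):
    def __init__(self, val):
        self.val = val

    def __str__(self):
        return str(self.val)

    def eval(self, env):
        return self.val


class Incomplete(Expr):
    # Hypothetical subclass that forgets to implement eval.
    def __str__(self):
        return "?"


def can_instantiate(cls, *args):
    """Return True if cls(*args) succeeds, False if it raises TypeError."""
    try:
        cls(*args)
        return True
    except TypeError:
        return False


print(can_instantiate(Expr))        # False: abstract base class
print(can_instantiate(Incomplete))  # False: eval is missing
print(can_instantiate(Con, 3))      # True
```

The trade-off: with abc, a subclass that forgets a method fails as soon as it is instantiated, rather than only when the missing method is eventually called.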

The fundamental expressions I implemented contain either no sub-expressions or exactly two, so I created specialised base classes for singular and binary expressions.

Singular Expressions

Singular expressions are either constants or variables. We add an empty base class for now:

class SinExpr(Expr):
    pass

The methods will need to be defined in the derived classes.

The simplest case of an expression is just a constant. We initialise it with its value, and when we evaluate it, we simply return the value:

class Con(SinExpr):
    def __init__(self, val):
        self.val = val

    def eval(self, env):
        return self.val

The other case of singular expressions we need are variables. These are initialised with a name, i.e. a string. On evaluation, the value associated with the name is looked up in the environment dictionary and returned:

class Var(SinExpr):
    def __init__(self, name):
        self.name = name

    def eval(self, env):
        return env[self.name]
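A quick check of the two singular expressions in isolation, as a self-contained sketch in which the base class is reduced to an empty stand-in: a constant ignores the environment, a variable looks itself up in it, and a variable missing from the dictionary surfaces as Python's built-in KeyError.

```python
class SinExpr:
    """Empty stand-in for the post's SinExpr base class (simplified sketch)."""
    pass


class Con(SinExpr):
    def __init__(self, val):
        self.val = val

    def eval(self, env):
        # A constant ignores the environment entirely.
        return self.val


class Var(SinExpr):
    def __init__(self, name):
        self.name = name

    def eval(self, env):
        # Look the variable up by name; a missing name raises KeyError.
        return env[self.name]


env = {'x': 4, 'y': 7}
print(Con(3).eval(env))    # 3
print(Var('x').eval(env))  # 4

try:
    Var('z').eval(env)
except KeyError as err:
    print("unbound variable:", err)  # unbound variable: 'z'
```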

Binary Expressions

Binary expressions are a bit more complicated. First of all, they obviously need to be given two other expressions in the initialisation:

class BinExpr(Expr):
    def __init__(self, a, b):
        self.a = a
        self.b = b

For evaluation, we need to evaluate both these expressions first and then combine them in the correct way. This combination will need to be defined by the derived classes in the internal method _binOperator:

    def eval(self, env):
        return self._binOperator(self.a.eval(env), self.b.eval(env))

    def _binOperator(self, a, b):
        raise NotImplementedError()

Our derived classes now look like this:

class Add(BinExpr):
    def _binOperator(self, a, b):
        return a+b

class Sub(BinExpr):
    def _binOperator(self, a, b):
        return a-b

class Mul(BinExpr):
    def _binOperator(self, a, b):
        return a*b

class Div(BinExpr):
    def _binOperator(self, a, b):
        return a/b

class Pow(BinExpr):
    def _binOperator(self, a, b):
        return a**b
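Assembled into one runnable piece, the evaluation part so far looks roughly like this. It is a consolidation of the snippets above, with __str__ and the part-two methods left out:

```python
class Expr:
    def eval(self, env):
        # Evaluate the expression, looking up variable values in the dict env.
        raise NotImplementedError()


class SinExpr(Expr):
    pass


class Con(SinExpr):
    def __init__(self, val):
        self.val = val

    def eval(self, env):
        return self.val


class Var(SinExpr):
    def __init__(self, name):
        self.name = name

    def eval(self, env):
        return env[self.name]


class BinExpr(Expr):
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def eval(self, env):
        # Evaluate both children first, then combine the results.
        return self._binOperator(self.a.eval(env), self.b.eval(env))

    def _binOperator(self, a, b):
        raise NotImplementedError()


class Add(BinExpr):
    def _binOperator(self, a, b):
        return a + b


class Sub(BinExpr):
    def _binOperator(self, a, b):
        return a - b


class Mul(BinExpr):
    def _binOperator(self, a, b):
        return a * b


class Div(BinExpr):
    def _binOperator(self, a, b):
        return a / b


class Pow(BinExpr):
    def _binOperator(self, a, b):
        return a ** b


# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3), Pow(Var('x'), Con(2))), Sub(Var('y'), Div(Con(6), Con(3))))
print(expr.eval({'x': 4, 'y': 7}))  # 53.0
```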

So far nothing too complicated, but we can already use these expressions for calculations. Let’s add a few tests to make sure they work:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
assert expr.eval({'x':1, 'y':2}) == 3
assert expr.eval({'x':1, 'y':7}) == 8
assert expr.eval({'x':4, 'y':7}) == 53
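Tracing the third assertion through the tree by hand, with x = 4 and y = 7 and the innermost nodes evaluated first:

```latex
3 \cdot 4^2 + \left(7 - \frac{6}{3}\right) = 3 \cdot 16 + (7 - 2) = 48 + 5 = 53
```

Because Div uses Python's true division /, the subtree 6/3 evaluates to the float 2.0, so eval actually returns 53.0; the assertion still passes because 53.0 == 53 is True in Python.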

Conversion to Strings

Ok, now we want to be able to print the expressions as strings, so we need to define __str__() for all our classes. For the singular expressions, this is straightforward:

class Con(SinExpr):
    # ...
    def __str__(self):
        return str(self.val)

class Var(SinExpr):
    # ...
    def __str__(self):
        return self.name

For binary expressions, we can put more of the shared logic into the base class. To print them as a string, we combine the string representations of the two sub-expressions a and b with the matching operator symbol in between (for example a*b for multiplication):

class BinExpr(Expr):
    # ...
    def __str__(self):
        a_str = str(self.a)
        b_str = str(self.b)
        return a_str + self._strOperator + b_str

    @property
    def _strOperator(self):
        raise NotImplementedError()

This symbol will need to be defined by the derived classes as the property _strOperator:

class Add(BinExpr):
    # ...
    @property
    def _strOperator(self):
        return "+"

# Similar for the other binary expression classes
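Filling in the remaining symbols gives the following minimal sketch. The ^ for Pow matches the goal output 3*x^2; the other symbols are the obvious choices. As a simplification, the sketch uses plain class attributes instead of properties; the lookup self._strOperator in __str__ works the same way for both.

```python
class Expr:
    pass


class Con(Expr):
    def __init__(self, val):
        self.val = val

    def __str__(self):
        return str(self.val)


class Var(Expr):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


class BinExpr(Expr):
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        # Naive version: operands are joined without parentheses.
        return str(self.a) + self._strOperator + str(self.b)


# Class attributes instead of properties (simplification for this sketch).
class Add(BinExpr):
    _strOperator = "+"


class Sub(BinExpr):
    _strOperator = "-"


class Mul(BinExpr):
    _strOperator = "*"


class Div(BinExpr):
    _strOperator = "/"


class Pow(BinExpr):
    _strOperator = "^"


print(Sub(Var('y'), Div(Con(6), Con(3))))  # y-6/3
print(Pow(Var('x'), Con(2)))               # x^2
```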

Now, what about the order of operations? The tree structure, which we create by nesting the expressions, defines it unambiguously. However, we need to make sure that it is also visible in the string representation. Let’s say we nest a multiplication inside an addition: Add(Mul(Var('a'), Var('b')), Var('c')). The multiplication needs to be performed before the addition. Printing this case as a*b+c is fine, since multiplication binds more strongly than addition.

In the opposite case of Mul(Var('a'), Add(Var('b'), Var('c'))), a*b+c would be a wrong representation; we need to write a*(b+c) to make the order of operations clear. We create a _bindingLevel property to model this: it defines how strongly an expression binds, so that we can determine whether parentheses are needed:
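To see the problem concretely, here is a stripped-down sketch (only Add and Mul, class attributes instead of properties) in which __str__ joins the operands without parentheses. The two trees from the paragraph print identically but evaluate differently for a = 2, b = 3, c = 4:

```python
class Var:
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def eval(self, env):
        return env[self.name]


class BinExpr:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        # Naive printing: operands are joined without any parentheses.
        return str(self.a) + self._strOperator + str(self.b)

    def eval(self, env):
        return self._binOperator(self.a.eval(env), self.b.eval(env))


class Add(BinExpr):
    _strOperator = "+"

    def _binOperator(self, a, b):
        return a + b


class Mul(BinExpr):
    _strOperator = "*"

    def _binOperator(self, a, b):
        return a * b


env = {'a': 2, 'b': 3, 'c': 4}
first = Add(Mul(Var('a'), Var('b')), Var('c'))   # (a*b)+c
second = Mul(Var('a'), Add(Var('b'), Var('c')))  # a*(b+c)

print(str(first), first.eval(env))    # a*b+c 10
print(str(second), second.eval(env))  # a*b+c 14
```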

class BinExpr(Expr):
    # ...
    def __str__(self):
        a_str = str(self.a)
        b_str = str(self.b)
        if self.a._bindingLevel < self._bindingLevel:
            a_str = "(" + a_str + ")"
        if self.b._bindingLevel < self._bindingLevel:
            b_str = "(" + b_str + ")"
        return a_str + self._strOperator + b_str

    @property
    def _bindingLevel(self):
        raise NotImplementedError()

While singular expressions always bind most strongly, the value for binary expressions needs to be defined in the derived classes:

class SinExpr(Expr):
    # ...
    @property
    def _bindingLevel(self):
        return 4

class Add(BinExpr):
    # ...
    @property
    def _bindingLevel(self):
        return 1

# The value will be:
# 1 for Add and Sub
# 2 for Mul and Div
# 3 for Pow
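Putting the printing pieces together gives the following sketch of the string conversion only (evaluation omitted), again with class attributes standing in for the properties and using the binding levels listed above:

```python
class Expr:
    pass


class SinExpr(Expr):
    # Constants and variables bind most strongly and never need parentheses.
    _bindingLevel = 4


class Con(SinExpr):
    def __init__(self, val):
        self.val = val

    def __str__(self):
        return str(self.val)


class Var(SinExpr):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


class BinExpr(Expr):
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        a_str = str(self.a)
        b_str = str(self.b)
        # Parenthesize a child that binds more weakly than this operator.
        if self.a._bindingLevel < self._bindingLevel:
            a_str = "(" + a_str + ")"
        if self.b._bindingLevel < self._bindingLevel:
            b_str = "(" + b_str + ")"
        return a_str + self._strOperator + b_str


class Add(BinExpr):
    _strOperator = "+"
    _bindingLevel = 1


class Sub(BinExpr):
    _strOperator = "-"
    _bindingLevel = 1


class Mul(BinExpr):
    _strOperator = "*"
    _bindingLevel = 2


class Div(BinExpr):
    _strOperator = "/"
    _bindingLevel = 2


class Pow(BinExpr):
    _strOperator = "^"
    _bindingLevel = 3


expr = Add(Mul(Con(3), Pow(Var('x'), Con(2))), Sub(Var('y'), Div(Con(6), Con(3))))
print(expr)                                    # 3*x^2+y-6/3
print(Mul(Var('a'), Add(Var('b'), Var('c'))))  # a*(b+c)
print(Pow(Add(Var('x'), Con(1)), Con(2)))      # (x+1)^2
```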

Let’s add another test with the same example as above:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
assert str(expr) == '3*x^2+y-6/3'
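One caveat remains with the strict < comparison: a child with the same binding level as its parent never gets parentheses. That is harmless for Add and Mul, but Sub(Var('a'), Sub(Var('b'), Var('c'))) prints as a-b-c, which reads as (a-b)-c. Nested Div has the same problem, and Pow(Pow(Var('a'), Var('b')), Var('c')) prints as a^b^c, which is conventionally read as a^(b^c). Below is a minimal sketch of one possible fix (printing only). It assumes two hypothetical class flags, _parenLeftOnTie and _parenRightOnTie, which are not part of the code in this post:

```python
class Expr:
    pass


class SinExpr(Expr):
    _bindingLevel = 4


class Con(SinExpr):
    def __init__(self, val):
        self.val = val

    def __str__(self):
        return str(self.val)


class Var(SinExpr):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


class BinExpr(Expr):
    # Hypothetical flags (not in the post): also parenthesize an operand
    # on this side when it binds exactly as strongly as this operator.
    _parenLeftOnTie = False
    _parenRightOnTie = False

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        a_str, b_str = str(self.a), str(self.b)
        lvl, a_lvl, b_lvl = self._bindingLevel, self.a._bindingLevel, self.b._bindingLevel
        if a_lvl < lvl or (a_lvl == lvl and self._parenLeftOnTie):
            a_str = "(" + a_str + ")"
        if b_lvl < lvl or (b_lvl == lvl and self._parenRightOnTie):
            b_str = "(" + b_str + ")"
        return a_str + self._strOperator + b_str


class Add(BinExpr):
    _strOperator = "+"
    _bindingLevel = 1


class Sub(BinExpr):
    _strOperator = "-"
    _bindingLevel = 1
    _parenRightOnTie = True  # a-(b-c) differs from a-b-c


class Mul(BinExpr):
    _strOperator = "*"
    _bindingLevel = 2


class Div(BinExpr):
    _strOperator = "/"
    _bindingLevel = 2
    _parenRightOnTie = True  # a/(b/c) differs from a/b/c


class Pow(BinExpr):
    _strOperator = "^"
    _bindingLevel = 3
    _parenLeftOnTie = True  # a^b^c is usually read as a^(b^c)


a, b, c = Var('a'), Var('b'), Var('c')
print(Sub(a, Sub(b, c)))  # a-(b-c)
print(Sub(Sub(a, b), c))  # a-b-c
print(Div(a, Mul(b, c)))  # a/(b*c)
print(Pow(Pow(a, b), c))  # (a^b)^c
```

The test case above still prints as 3*x^2+y-6/3 with this change, because Add never parenthesizes an operand on a tie.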

More Functionality

So far we have mainly recreated what was already shown in the YouTube video, and to be honest, this is not very useful yet. So I extended this base structure in a few ways, which I will describe in part two.