Mathematical Expressions as Trees 2: Derivations and Expression Simplification

If you haven’t done so yet, read part one first or watch the video that inspired these posts.

You can find a notebook with the state after this blog post here.

More Functionality

So far, we have mainly recreated what was already shown in the YouTube video, and to be honest, this is not very useful yet. So I extended this base structure with a few features.

Partial Derivatives

The obvious extension, prompted by the video, is partial derivatives. I already showed in part one how this is supposed to work:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
str(expr.partial('x')) # == '0*x^2+3*1*2*x^(2-1)+0-(0*3-6*0)/3^2'

This does not look nice, but it is correct. Making it look nice is the second step.

In principle, this is not too complicated. We just need to implement the method partial(var) for every expression class, based on what we know about derivatives from calculus.

The simplest cases are again the singular expressions: the derivative of a constant is 0, and the derivative of a variable is 1 if it is the variable we differentiate with respect to, and 0 otherwise:

class Con(SinExpr):
    # ...
    def partial(self, var):
        return Con(0)

class Var(SinExpr):
    # ...
    def partial(self, var):
        if var == self.name:
            return Con(1)
        else:
            return Con(0)

The more interesting cases are the binary expressions. Here the derivative depends recursively on the derivatives of the child expressions. Sums and differences are differentiated term by term:

class Add(BinExpr):
    # ...
    def partial(self, var):
        return Add(self.a.partial(var), self.b.partial(var))

class Sub(BinExpr):
    # ...
    def partial(self, var):
        return Sub(self.a.partial(var), self.b.partial(var))

For products and quotients we use the product rule and the quotient rule:

class Mul(BinExpr):
    # ...
    def partial(self, var):
        return Add(Mul(self.a.partial(var), self.b),
                   Mul(self.a, self.b.partial(var)))

class Div(BinExpr):
    # ...
    def partial(self, var):
        return Div(Sub(Mul(self.a.partial(var), self.b),
                       Mul(self.a, self.b.partial(var))),
                   Pow(self.b, Con(2)))
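For reference, these are the two rules that Mul.partial and Div.partial encode, written with a and b for the two child expressions:

```latex
\frac{\partial (a\,b)}{\partial x} = \frac{\partial a}{\partial x}\,b + a\,\frac{\partial b}{\partial x},
\qquad
\frac{\partial}{\partial x}\left(\frac{a}{b}\right) = \frac{\frac{\partial a}{\partial x}\,b - a\,\frac{\partial b}{\partial x}}{b^{2}}
```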

For powers, we use the power rule (x^n)' = n * x^(n-1) together with the chain rule. However, we need to make sure that the derivative of the exponent is 0; otherwise we have an exponential function (which we could in principle allow for, but for simplicity it is left out for now):

class Pow(BinExpr):
    # ...
    def partial(self, var):
        # only constant exponents are supported (constValue is a property)
        if self.b.partial(var).constValue != 0:
            raise ValueError("The exponent is dependent on {}!".format(var))
        # power rule with chain rule: (u^n)' = u' * n * u^(n-1)
        return Mul(self.a.partial(var),
                   Mul(self.b,
                       Pow(self.a, Sub(self.b, Con(1)))))

(constValue will be explained in the next section.)
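Written out, Pow.partial implements the power rule combined with the chain rule, which holds only for an exponent n that does not depend on x. For comparison, the general rule for u^v (valid for u > 0) shows the extra logarithm term that an exponential function would require:

```latex
\frac{\partial}{\partial x}\,u^{n} = \frac{\partial u}{\partial x}\cdot n\,u^{n-1}
\quad \text{if } \frac{\partial n}{\partial x} = 0
\\[1ex]
\frac{\partial}{\partial x}\,u^{v} = u^{v}\left(\frac{\partial v}{\partial x}\,\ln u + \frac{v}{u}\,\frac{\partial u}{\partial x}\right)
```

Setting the derivative of v to zero in the second line recovers the first.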

With this, we can already differentiate with respect to different variables. Let’s test it:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
assert str(expr.partial('x')) == '0*x^2+3*1*2*x^(2-1)+0-(0*3-6*0)/3^2'
assert str(expr.partial('y')) == '0*x^2+3*0*2*x^(2-1)+1-(0*3-6*0)/3^2'
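These string comparisons check the structure of the derivative tree, but not whether its value is right. One way to check the values is to evaluate the derivative at a point and compare it with a central finite difference. The sketch below is a compact, self-contained re-implementation of the classes from this post under some assumptions: a class attribute op stands in for part one's _binOperator, the base classes and printing are left out, and evaluate(env) and numeric_partial are hypothetical helpers that are not part of the original code.

```python
import operator

# Compact sketch of the expression classes. Assumptions: `op` stands in for
# part one's _binOperator; evaluate(env) and numeric_partial are hypothetical.

class Con:
    def __init__(self, val):
        self.val = val
    @property
    def constValue(self):
        return self.val
    def partial(self, var):
        return Con(0)
    def evaluate(self, env):
        return self.val

class Var:
    def __init__(self, name):
        self.name = name
    @property
    def constValue(self):
        return None
    def partial(self, var):
        return Con(1) if var == self.name else Con(0)
    def evaluate(self, env):
        return env[self.name]

class BinExpr:
    op = None  # set by each subclass
    def __init__(self, a, b):
        self.a, self.b = a, b
    @property
    def constValue(self):
        val_a, val_b = self.a.constValue, self.b.constValue
        if val_a is not None and val_b is not None:
            return self.op(val_a, val_b)
        return None
    def evaluate(self, env):
        return self.op(self.a.evaluate(env), self.b.evaluate(env))

class Add(BinExpr):
    op = staticmethod(operator.add)
    def partial(self, var):
        return Add(self.a.partial(var), self.b.partial(var))

class Sub(BinExpr):
    op = staticmethod(operator.sub)
    def partial(self, var):
        return Sub(self.a.partial(var), self.b.partial(var))

class Mul(BinExpr):
    op = staticmethod(operator.mul)
    def partial(self, var):  # product rule
        return Add(Mul(self.a.partial(var), self.b),
                   Mul(self.a, self.b.partial(var)))

class Div(BinExpr):
    op = staticmethod(operator.truediv)
    def partial(self, var):  # quotient rule
        return Div(Sub(Mul(self.a.partial(var), self.b),
                       Mul(self.a, self.b.partial(var))),
                   Pow(self.b, Con(2)))

class Pow(BinExpr):
    op = staticmethod(operator.pow)
    def partial(self, var):  # power rule + chain rule, constant exponent only
        if self.b.partial(var).constValue != 0:
            raise ValueError("The exponent is dependent on {}!".format(var))
        return Mul(self.a.partial(var),
                   Mul(self.b, Pow(self.a, Sub(self.b, Con(1)))))

def numeric_partial(expr, var, env, h=1e-6):
    """Central finite difference of expr with respect to var at env."""
    up = dict(env, **{var: env[var] + h})
    down = dict(env, **{var: env[var] - h})
    return (expr.evaluate(up) - expr.evaluate(down)) / (2 * h)

# 3*x^2 + y - 6/3, evaluated at x = 1.5, y = -2
expr = Add(Mul(Con(3), Pow(Var('x'), Con(2))),
           Sub(Var('y'), Div(Con(6), Con(3))))
env = {'x': 1.5, 'y': -2.0}
for var in ('x', 'y'):
    print(var, expr.partial(var).evaluate(env), numeric_partial(expr, var, env))
```

At x = 1.5 the derivative with respect to x should be 6*x = 9 and the one with respect to y should be 1; the finite differences should agree up to rounding error.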

Simplification of Expressions

The partial derivatives work now, but the expressions we get are not very pretty. You would probably expect something like this:

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
assert str(expr.partial('x')) == '6*x'
assert str(expr.partial('y')) == '1'

Let’s try to get as close to this ideal as possible. For this purpose we create the method simplify(), which should return an expression that is mathematically identical to the original, but expressed in a simpler way.

Singular expressions are either constants or variables and cannot be simplified any further, so we can implement that method directly in the base class:

class SinExpr(Expr):
    def simplify(self):
        return self

The more challenging part is the binary expressions. The simplest thing we can do here is to check whether we can already calculate the value of the expression and, if so, replace it with a Con expression. For this purpose, we define another property: constValue should return the value of an expression if it can be calculated, and None otherwise. For singular expressions this is again trivial:

class Con(SinExpr):
    # ...
    @property
    def constValue(self):
        return self.val

class Var(SinExpr):
    # ...
    @property
    def constValue(self):
        return None

For binary expressions, we recursively check whether both child expressions are constant. If so, we calculate the value of the binary expression; otherwise we return None:

class BinExpr(Expr):
    # ...
    @property
    def constValue(self):
        # constValue is a property, so it is accessed without parentheses
        val_a = self.a.constValue
        val_b = self.b.constValue
        # compare with None explicitly: 0 is a valid constant but falsy
        if val_a is not None and val_b is not None:
            return self._binOperator(val_a, val_b)
        else:
            return None
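Derivative trees are full of Con(0), which is why the None check has to be explicit: a test written as "if val_a and val_b" treats a constant 0 like a missing value. A minimal standalone illustration, using two hypothetical functions that take plain values instead of expression objects:

```python
import operator


def const_value_truthy(val_a, val_b, op):
    # Truthiness check: a constant 0 is falsy and is wrongly
    # treated as "not constant"
    if val_a and val_b:
        return op(val_a, val_b)
    return None


def const_value_explicit(val_a, val_b, op):
    # Only None means "not constant"; 0 is a perfectly good value
    if val_a is not None and val_b is not None:
        return op(val_a, val_b)
    return None


# 0 * 3, as it appears all over the derivative trees
print(const_value_truthy(0, 3, operator.mul))    # None
print(const_value_explicit(0, 3, operator.mul))  # 0
```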

Great! We can now use this to define the simplify() method for binary expressions:

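class BinExpr(Expr):
    # ...
    def simplify(self):
        # recursively simplify children (note: this modifies self in place)
        self.a = self.a.simplify()
        self.b = self.b.simplify()
        # if both children are constant, replace self by a constant
        value = self.constValue
        if value is not None:
            return Con(value)
        # otherwise try the special tricks of the derived class
        return self._selfSimplify()

    def _selfSimplify(self):
        # must be overridden by every derived class
        raise NotImplementedError()

    # shorthands for the constant values of the children (None if not constant)
    @property
    def cv_a(self):
        return self.a.constValue

    @property
    def cv_b(self):
        return self.b.constValue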

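What is this _selfSimplify()? Shouldn’t we just return self there? Well, even if we can’t calculate a constant value for this expression, we might be able to apply some tricks we know from algebra. Here, cv_a and cv_b denote the constant values of the two children (self.a.constValue and self.b.constValue), which are None for non-constant children. For example, if we add two things and one of them is zero, we can just return the other: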

class Add(BinExpr):
    # ...
    def _selfSimplify(self):
        if self.cv_a == 0:
            return self.b
        if self.cv_b == 0:
            return self.a
        return self

If that trick is not applicable, we indeed default to returning self. Similar tricks can be applied to the other expressions as well:

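class Sub(BinExpr):
    # ...
    def _selfSimplify(self):
        # a - a == 0. Comparing cv_a == cv_b would also be true for two
        # different non-constant children (None == None), so compare the
        # printed forms instead
        if str(self.a) == str(self.b):
            return Con(0)
        if self.cv_b == 0:
            return self.a
        return self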

class Mul(BinExpr):
    # ...
    def _selfSimplify(self):
        if self.cv_a == 0 or self.cv_b == 0:
            return Con(0)
        if self.cv_a == 1:
            return self.b
        if self.cv_b == 1:
            return self.a
        return self

class Div(BinExpr):
    # ...
    def _selfSimplify(self):
        if self.cv_b == 0:
            raise ValueError("Division by zero")
        if self.cv_a == 0:
            return Con(0)
        if self.cv_b == 1:
            return self.a
        return self

class Pow(BinExpr):
    # ...
    def _selfSimplify(self):
        if self.cv_b == 0:
            return Con(1)
        if self.cv_b == 1:
            return self.a
        return self

Fine Tuning

OK, so what do we get now if we simplify our expression from above?

# 3*x^2 + y - 6/3
expr = Add(Mul(Con(3),Pow(Var('x'),Con(2))),Sub(Var('y'),Div(Con(6),Con(3))))
str(expr.simplify()) # returns '3*x^2+y-2.0'

The 2.0 bothers me a little. It is not wrong, but a clean 2 would be nicer. Also, if we have fractions like 1/3, it would be good to keep this as a fraction instead of 0.3333333333333333.

For the first point, we can extend the initializer of Con:

class Con(SinExpr):
    def __init__(self, val):
        # store whole numbers such as 2.0 as int, so they are printed as "2"
        self.val = int(val) if int(val) == val else val
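The same normalization as a standalone sketch, written as a hypothetical helper function. It uses float.is_integer(), which expresses the same test and, unlike int(), does not raise for float('inf') or float('nan') but simply returns False:

```python
def normalize(val):
    # Hypothetical helper mirroring the Con initializer: whole-number floats
    # become int, everything else is kept as-is. float.is_integer() returns
    # False for inf and nan, where int(val) would raise instead.
    if isinstance(val, float) and val.is_integer():
        return int(val)
    return val


print(normalize(2.0), normalize(0.25), normalize(3))  # 2 0.25 3
```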

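For the second point, we tighten the constant check in the simplify() method of BinExpr. Let’s say we want to fold a constant only if it needs at most three decimal digits:

class BinExpr(Expr):
    # ...
    def simplify(self):
        # ... (simplify the children as before)
        value = self.constValue
        # fold only constants with at most three decimal digits,
        # so that e.g. 1/3 is kept as a fraction
        if value is not None and round(value, 3) == value:
            return Con(value)
        # ...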
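To see which constants this three-digit criterion folds, here it is as a hypothetical standalone function applied to a few quotients:

```python
def has_short_decimal(value, digits=3):
    # Hypothetical helper: the folding criterion from simplify(), i.e. the
    # value does not change when rounded to `digits` decimal places
    return round(value, digits) == value


for num, den in [(6, 3), (1, 4), (1, 3), (2, 3)]:
    print(f"{num}/{den} = {num / den}: fold = {has_short_decimal(num / den)}")
```

Because floats are binary, the criterion can also reject values that look short in decimal: 0.1 + 0.2 evaluates to 0.30000000000000004, so Add(Con(0.1), Con(0.2)) would stay unfolded. Storing constants as Python's fractions.Fraction would be one alternative for keeping fractions exact.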

This is now good enough to pass the following tests:

assert str(Con(2.0).simplify()) == '2'
assert str(Div(Con(1), Con(3)).simplify()) == '1/3'
assert str(Div(Con(1), Con(4)).simplify()) == '0.25'

Outlook

There is still one thing missing in the simplification (that I know of). Look at these two expressions:

str(Mul(Con(2), Mul(Con(3), Var('x'))).simplify()) # returns '2*3*x'
str(Mul(Mul(Con(2), Con(3)), Var('x')).simplify()) # returns '6*x'

Both should really return the second form. But this can wait for part three.