# ChainRules
Jeff Bezanson, Stefan Karpinski, Viral Shah, Alan Edelman

# How does Forward Mode AD work?

Forward-mode AD means replacing every function with a function that calculates the primal result and pushes forward the derivative.

How do we get such a function? Either we have a `frule` giving us one, or we open up the function and replace every function inside it with such a propagating function.
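As a rough illustration of "opening up the function", here is a minimal sketch that avoids any package dependency. The names `my_frule` and `foo_forward` are hypothetical stand-ins invented for this example and are not part of ChainRulesCore; the derivative formulas for `sin` and `asin` are the standard ones.

```julia
# Hypothetical stand-ins for `frule`: each takes the input tangent ẋ and the
# primal input x, and returns the primal output together with its tangent.
my_frule(ẋ, ::typeof(sin), x) = (sin(x), cos(x) * ẋ)
my_frule(ẋ, ::typeof(asin), x) = (asin(x), ẋ / sqrt(1 - x^2))

# A function computing asin(sin(x)), opened up by hand: every call inside it
# is replaced by its propagating version.
function foo_forward(x, ẋ)
    u, u̇ = my_frule(ẋ, sin, x)
    v, v̇ = my_frule(u̇, asin, u)
    return v, v̇
end

v, v̇ = foo_forward(π/4, 1.0)
```

Each call carries its tangent along with its primal value, so no tape is needed: the derivative is ready as soon as the primal result is.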

## Let's do AD by hand: forward-mode

```julia
function foo(x)
    u = sin(x)
    v = asin(u)
    return v
end
```

```
foo (generic function with 1 method)
```

.row[
.col[
```julia-repl
julia> x = π/4;

julia> ẋ = 1.0;

julia> u, u̇ = frule((NO_FIELDS, ẋ), sin, x)
(0.7071067811865475, 0.7071067811865476)

julia> v, v̇ = frule((NO_FIELDS, u̇), asin, u)
(0.7853981633974482, 1.0)

julia> v̇
1.0
```
]
.col[

$\dot{x}=\textcolor{blue}{\dfrac{\partial x}{\partial x}}$
$\dot{u}= \textcolor{green}{\dfrac{\partial u}{\partial x}} =\dfrac{\partial u}{\partial x} \textcolor{blue}{\dfrac{\partial x}{\partial x}}$
$\dot{v}= \textcolor{purple}{\dfrac{\partial v}{\partial x}} =\dfrac{\partial v}{\partial u} \textcolor{green}{\dfrac{\partial u}{\partial x}}$
]
]

---

# How does Reverse Mode AD work?

Reverse-mode AD means replacing every function with a function that calculates the primal result and stores the pullback onto a tape, which it then composes backwards at the end to pull all the way back.

--

How do we get such a function that tells us the pullback? Either we have an `rrule` giving us one, or we open up the function and replace every function inside it with such a propagating function.

---

## Let's do AD by hand: reverse-mode

```julia
function foo(x)
    u = sin(x)
    v = asin(u)
    return v
end
```

First the forward pass, computing the pullbacks, which we would record onto the tape:

```julia-repl
julia> x = π/4
0.7853981633974483

julia> u, u_pullback = rrule(sin, x)
(0.7071067811865475, ChainRules.var"#477#sin_pullback#117"{Float64}(0.7853981633974483))

julia> v, v_pullback = rrule(asin, u)
(0.7853981633974482, ChainRules.var"#501#asin_pullback#123"{Float64}(0.7071067811865475))
```

---

## Let's do AD by hand: reverse-mode

```julia
function foo(x)
    u = sin(x)
    v = asin(u)
    return v
end
```

Then the backward pass, calculating gradients:

.row[
.col[
```julia-repl
julia> v̅ = 1;

julia> _, u̅ = v_pullback(v̅)
(ChainRulesCore.Zero(), 1.414213562373095)

julia> _, x̄ = u_pullback(u̅)
(ChainRulesCore.Zero(), 1.0)

julia> x̄
1.0
```
]
.col[
$\bar{v}=\textcolor{blue}{\dfrac{\partial v}{\partial v}}$
$\bar{u}=\textcolor{green}{\dfrac{\partial v}{\partial u}} =\textcolor{blue}{\dfrac{\partial v}{\partial v}}\dfrac{\partial v}{\partial u}$
$\bar{x}= \textcolor{purple}{\dfrac{\partial v}{\partial x}} =\textcolor{green}{\dfrac{\partial v}{\partial u}} \dfrac{\partial u}{\partial x}$
] ] ---
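The tape itself can be sketched without any packages. This is only an illustration under stated assumptions: `my_rrule` and `foo_reverse` are hypothetical names made up here, the pullbacks return just the input gradient (no gradient for the function itself), and the tape is a plain vector of closures rather than anything a real AD would use.

```julia
# Hypothetical stand-ins for `rrule`: each returns the primal output and a
# pullback closure mapping the output gradient to the input gradient.
my_rrule(::typeof(sin), x) = (sin(x), ȳ -> ȳ * cos(x))
my_rrule(::typeof(asin), x) = (asin(x), ȳ -> ȳ / sqrt(1 - x^2))

# Computes asin(sin(x)) together with its gradient with respect to x.
function foo_reverse(x)
    tape = Function[]                 # pullbacks, in the order they were recorded
    u, u_pullback = my_rrule(sin, x)
    push!(tape, u_pullback)
    v, v_pullback = my_rrule(asin, u)
    push!(tape, v_pullback)

    # Backward pass: seed with ∂v/∂v = 1 and apply the pullbacks last-to-first.
    grad = 1.0
    for pullback in reverse(tape)
        grad = pullback(grad)
    end
    return v, grad
end

v, x̄ = foo_reverse(π/4)
```

The forward pass only records closures; nothing is differentiated until the loop walks the tape backwards.

---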
# A series of needs
---

# What does AD need?

* Ability to decompose functions down into primitive operations that it has **rules** for.
* Ability to recompose those rules and the results to get overall derivatives.
* A collection of those **rules**: ChainRules

---

# Why does AD need rules?

* Fundamentally, it needs rules for the instruction set: `+`, `*`, etc.
* Rules insert domain knowledge about the best way to find a derivative. Extreme example: *QuadGK + Fundamental Theorem of Calculus*: differentiating an integral with respect to its upper limit just returns the integrand.
* It needs rules to handle things the AD can't deal with (e.g. Zygote's current lack of mutation support).

.funfact[
Instruction set size varies a lot across ADs. E.g. PyTorch, JAX, MXNet have an instruction set that is $\approx~\text{NumPy's API}$. Enzyme has an instruction set that is the LLVM instruction set.
]

---

# What does ChainRules need?

--

### An AD-agnostic system for writing rules

ChainRulesCore.jl

--

### An inventory of actual rules for Base and StdLibs

ChainRules.jl

--

### A way to test they are right

ChainRulesTestUtils.jl

---

# The ChainRules project fills those needs
---

# What does ChainRulesCore need?

A way to specify what the rule is for a given method: i.e. function + argument types.

This is done by overloading `frule` and `rrule`:

`rrule(::typeof(foo), args...; kws...)`

`frule((ṡelf, ȧrgs...), ::typeof(foo), args...; kws...)`

---

# What does a `frule` need?

We know we are going to need to compute the primal, so we need the primal inputs.

What else do we need to allow us to propagate the directional derivative forwards?

We need the directional derivative that is being pushed forward.

```julia
function frule((ṡelf, ȧrgs...), ::typeof(foo), args...; kwargs...)
    ...
    return y, ẏ
end
```

.funfact[
We say that the **pushforward is fused into the frule**. This is required for efficient custom rules, e.g. for ODE solvers.
]

---

# What does an `rrule` need?

Primal inputs again, but what else do we need to propagate the gradient backwards?

We need the gradient of the function called after this one.

That's a problem: we don't have that on the forward pass, so we will need to return something to put on the tape for the backwards pass. The **pullback**.

```julia
function rrule(::typeof(foo), args...; kwargs...)
    y = ...
    function foo_pullback(ȳ)
        ...
        return s̄elf, ārgs...
    end
    return y, foo_pullback
end
```

---

# What do we need to represent the types of derivatives?

.row[
.col[
**Primal**
```julia
struct Foo
    a::Matrix{Float64}
    b::String
end
```
]
.col[
**Differential**
```julia
Composite{Foo}
# With properties:
a::Matrix{Float64}
b::DoesNotExist
```
]
]

.funfact[
There are multiple correct differentials for many types, and which to use is context dependent.
]

---

# What do differential types need?

Basically, they are elements of vector spaces. Roughly speaking, **every differential represents the difference between two primals**.

.funfact[
The differential for `DateTime` is `Period` (e.g. `Millisecond`). It is my favorite example of a differential for a primal that is *not* a vector space.
]

---

# What do differential types need? `zero`

They need a **zero**, since the two primals that a differential is the difference of could be equal, e.g. when the function being differentiated is a constant.

.funfact[
There is thus the trivial differential, `Zero()`, which can be added to anything without changing it. It's a valid differential for all primals.
]

---

# What do differential types need? `+`

They need to be able to be added to each other.

For $u = \sin(x)$, $v = \cos(x)$, $y = u + v$

$$ \dfrac{\partial y}{\partial x} = \dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial x} $$

.funfact[
An advantage of ChainRules' differential types over Zygote's use of `NamedTuple`s (ChainRules' `Composite`) and `Nothing` (ChainRules' `AbstractZero`) is that they actually overload `+`.
]

---

# What do differential types need, to be useful for gradient based optimization?

Vanilla Gradient Descent: $x \leftarrow x + (-0.1)\,\bar{x}$

So we need to be able to **add to primal** and **multiply by a scalar**.

Add to primal is the inverse of it being a difference of primals. Multiply by scalar is natural to define from limits of additions.

.funfact[
This is also useful for **difference** based (gradient-free) optimization like **particle swarms** and the **Nelder–Mead method**. Further, it is planned to switch FiniteDifferences.jl to use differentials, rather than converting to and from vectors.
]

---
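The requirements above (a zero, `+` between differentials, scaling, and adding to a primal) can be sketched on a toy type. Everything here is hypothetical and hand-rolled for illustration: `Point` and `PointDiff` are invented names and are not ChainRulesCore's `Composite` or `Zero`.

```julia
# Hypothetical primal type and a hand-rolled differential for it.
struct Point
    x::Float64
    y::Float64
end

struct PointDiff
    dx::Float64
    dy::Float64
end

Base.zero(::Type{PointDiff}) = PointDiff(0.0, 0.0)                        # the zero
Base.:+(a::PointDiff, b::PointDiff) = PointDiff(a.dx + b.dx, a.dy + b.dy) # differential + differential
Base.:*(c::Real, d::PointDiff) = PointDiff(c * d.dx, c * d.dy)            # scalar * differential
Base.:+(p::Point, d::PointDiff) = Point(p.x + d.dx, p.y + d.dy)           # primal + differential

# One step against the gradient of f(p) = p.x^2 + p.y^2, whose gradient is (2x, 2y).
p = Point(1.0, -2.0)
p̄ = PointDiff(2 * p.x, 2 * p.y)
p_next = p + (-0.1) * p̄
```

Note that `Point + Point` is deliberately left undefined: only the differential lives in a vector space, which mirrors the `DateTime` and `Period` example.

---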
# How is the world now?
# How many AD systems do we have?

**Reverse Mode:**

* AutoGrad
* Nabla
* ReverseDiff
* Tracker
* Yota
* Zygote

**Forward Mode:**

* ForwardDiff
* ForwardDiff2
* Zygote

.funfact[The intersection of Zygote and Nabla's rules is not much more than what both have from DiffRules.]
# ChainRules vs DiffRules:

**ChainRules** is the successor to **DiffRules**

* **DiffRules** only handles *scalar rules*
* Any rule defined in **DiffRules** can be defined using the `@scalar_rule` macro in **ChainRules** without change.
* **DiffRules** is *not designed* to have its list of rules extended by other packages. **ChainRulesCore** is.

# ChainRules vs ZygoteRules:

* **ChainRulesCore** and **ZygoteRules** are very similar
* ChainRules wasn't quite ready when ZygoteRules was created.
* ChainRules is not Zygote specific, it works with everything.

### ZygoteRules is effectively deprecated, and all new rules should be written using ChainRulesCore
---
# The Future
## Deeper Integration into Zygote

* Use in Forward Mode
* Use ChainRules' differential types
  * `nothing` -> `AbstractZero`
  * `NamedTuple` -> `Composite`
* Convenience macro for easy translating of ZygoteRules

# Better support for Overloading based AD

* Need to improve support for generating overloads from rules.
* Will also solve inference related issues.

### 🔜 ReverseDiff.jl

### 🔜 Nabla.jl

## ForwardDiff 🤷 *maybe*

It's so stable. ForwardDiff2 might take its place.

.funfact[
ForwardDiff was released on 13 April 2013. Julia v0.2 was released on 19 November 2013, 31 weeks later. That release included such features as Pkg2, keyword arguments, and suffixing mutating functions with `!`.
]

## Calling back into AD

```julia
function rrule(::typeof(map), f, x)
    # `unzip` is assumed to turn a collection of tuples into a tuple of collections.
    res = map(xi -> rrule(f, xi), x)
    ys, pullbacks = unzip(res)
    function map_pullback(ȳ)
        # Apply each element's pullback to the matching element of ȳ.
        s̄elf, x̄ = unzip(map(pullbacks, ȳ) do pullback_i, ȳi
            pullback_i(ȳi)
        end)
        return NO_FIELDS, s̄elf, x̄
    end
    return ys, map_pullback
end
```

**But this calls `rrule(f, xi)`, which might not be defined.** We may need to use an AD to find it. We could hard-code a given AD, but you probably want to keep using the one you were already using when you called `rrule(map, f, x)`.
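For readers who want to run the idea, here is a package-free sketch. It makes several simplifying assumptions: `unzip` is a hypothetical two-element helper written for this example, `my_rrule` stands in for `rrule(f, xi)` and is only defined for `sin`, and the gradient with respect to `f` itself is ignored.

```julia
# Assumed helper: turn a collection of 2-tuples into a 2-tuple of collections.
unzip(tuples) = (map(first, tuples), map(last, tuples))

# Hypothetical stand-in for `rrule(f, xi)`; only `sin` has a rule here.
my_rrule(::typeof(sin), x) = (sin(x), ȳ -> ȳ * cos(x))

function map_with_pullback(f, xs)
    ys, pullbacks = unzip(map(xi -> my_rrule(f, xi), xs))
    # Each element gets its own pullback, applied to the matching element of ȳ.
    map_pullback(ȳ) = map((pullback_i, ȳi) -> pullback_i(ȳi), pullbacks, ȳ)
    return ys, map_pullback
end

ys, back = map_with_pullback(sin, [0.0, π/2])
x̄ = back([1.0, 1.0])
```

Calling `map_with_pullback` with a function that has no `my_rrule` method raises a `MethodError`, which is exactly the gap that calling back into the AD is meant to fill.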

# Rules Everywhere

Just like **TimeZones.jl** depends on **RecipesBase.jl** to make `ZonedDateTime`s plot-able, packages like **DiffEqBase** depend on **ChainRulesCore.jl** and provide rules to make their functions differentiable, where required or where there are smart domain-knowledge ways to make it faster.

The future is more packages doing that.

# How to get involved

* Write rules for Base and StdLibs in ChainRules
* Write rules for your package with ChainRulesCore
* Incorporate ChainRules support into your favourite AD

.funfact[
There are 500 rules in Zygote that need migrating.
]

# Summary

* AD needs rules
* There will always be more AD systems.
* One set of rules to rule them all