The ideas in this post were hashed out during a series of discussions between myself and Bruno Gavranović.

Consider a system for forecasting a time series in \(\mathbb{R}\) based on a vector of features in \(\mathbb{R}^a\). At each time \(t\) this system will use the state of the world (represented as a vector in \(\mathbb{R}^a\)) to predict what the value of the time series (a real number) will be at time \(t+1\). At time \(t+1\) the system will receive information about the correct value of the time series at time \(t+1\), represented as a pair \((x_a, y) \in \mathbb{R}^a \times \mathbb{R}\), and will need to predict the value of the time series at time \(t+2\).

One way to build such a system would be to choose a loss function \(l: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\) and use gradient descent to train a model \(f: \mathbb{R}^p \times \mathbb{R}^a \rightarrow \mathbb{R}\) on all of the data before some time \(t\). We could then use that trained model to generate predictions at time \(t+n\). If the time series is not stationary then this may produce poor results for \(t+n\) where \(n\) is large.

Another option would be to continuously update the parameters \(x_p \in \mathbb{R}^p\) of the model \(f: \mathbb{R}^p \times \mathbb{R}^a \rightarrow \mathbb{R}\) as each new sample \((x_a, y) \in \mathbb{R}^a \times \mathbb{R}\) is observed. This is known as online learning.

An example online learning algorithm is stochastic gradient descent, which we can define as follows:

Initialize \(x_p\) randomly, and then as each sample \((x_{a_i}, y_{i})\) arrives set:

\[i \gets mod(i+1, |S|)\\ l_{f_i} \gets l(f(x_p, x_{a_i}), y_{i})\\ x_p \gets x_p - \alpha \nabla l_{f_i}(x_p)\]
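The update rule above can be sketched in a few lines of Python. This is only an illustration: the linear model \(f(x_p, x_a) = x_p \cdot x_a\), the squared-error loss, the names `grad_squared_loss` and `sgd_step`, and the step size `alpha=0.1` are all assumptions made for the example, and the parameters start at zero rather than at a random point so that the run is reproducible.

```python
def grad_squared_loss(params, features, target):
    """Gradient w.r.t. params of (params . features - target) ** 2."""
    prediction = sum(p * x for p, x in zip(params, features))
    residual = prediction - target
    return [2.0 * residual * x for x in features]


def sgd_step(params, sample, alpha=0.1):
    """One online update: x_p <- x_p - alpha * gradient of the loss at x_p."""
    features, target = sample
    grad = grad_squared_loss(params, features, target)
    return [p - alpha * g for p, g in zip(params, grad)]


# A toy stream of samples (x_a, y) generated by y = 2 * x_a[0] - x_a[1].
stream = [([1.0, 0.0], 2.0), ([0.0, 1.0], -1.0), ([1.0, 1.0], 1.0)] * 200

params = [0.0, 0.0]
for sample in stream:
    # The state (the parameters) evolves as each sample arrives.
    params = sgd_step(params, sample)
```

On this noiseless toy stream the parameters settle near the generating coefficients; on real, non-stationary data they would keep drifting with the series.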

Stochastic gradient descent describes a dynamical system in which a state in \(\mathbb{R}^p\) evolves over time in response to inputs in \(\mathbb{R}^a \times \mathbb{R}\). One of the downsides of stochastic gradient descent is that the update step can have very high variance from \(t\) to \(t+1\), which can slow down convergence. One way to get around this is to use the momentum algorithm, which we define as follows:

Initialize \(x_p, x_p'\) randomly, and then as each sample \((x_{a_i}, y_{i})\) arrives set:

\[l_{f_i} \gets l(f(x_p, x_{a_i}), y_{i})\\ x_p \gets x_p + \alpha x'_p \\ x'_p \gets (1 - \beta) x'_p - \beta \nabla l_{f_i}(x_p)\]
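A minimal Python sketch of one momentum step follows. As before, the linear model, the squared-error loss, the function names, and the values of `alpha` and `beta` are assumptions chosen for illustration. The sketch also assumes that the gradient is evaluated at the parameters as they were before the position update, which the update rule above leaves open.

```python
def grad_squared_loss(params, features, target):
    """Gradient w.r.t. params of (params . features - target) ** 2."""
    prediction = sum(p * x for p, x in zip(params, features))
    residual = prediction - target
    return [2.0 * residual * x for x in features]


def momentum_step(params, velocity, sample, alpha=0.1, beta=0.5):
    """One momentum update on the pair (x_p, x'_p)."""
    features, target = sample
    # Assumption: gradient taken at the parameters before they are moved.
    grad = grad_squared_loss(params, features, target)
    new_params = [p + alpha * v for p, v in zip(params, velocity)]
    new_velocity = [(1.0 - beta) * v - beta * g
                    for v, g in zip(velocity, grad)]
    return new_params, new_velocity


sample = ([1.0, 0.0], 2.0)
params, velocity = [0.0, 0.0], [0.0, 0.0]
# First step: the velocity picks up the gradient, the parameters do not move yet.
params, velocity = momentum_step(params, velocity, sample)
# Second step: the parameters move along the accumulated velocity.
params, velocity = momentum_step(params, velocity, sample)
```

The state carried between steps is now the pair of parameters and velocity, which is why the state space doubles to \(\mathbb{R}^p \times \mathbb{R}^p\).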

Momentum describes a dynamical system in which a state in \(\mathbb{R}^p \times \mathbb{R}^p\) evolves over time in response to inputs in \(\mathbb{R}^a \times \mathbb{R}\).

Lenses and Dynamical Systems

Optimization algorithms like stochastic gradient descent and momentum describe dynamical systems whose state is the function parameters. In this section we dig deeper into this perspective. In particular, we use David Jaz Myers’ category-theoretic formulation of dynamical systems to study how we can recombine simpler optimization algorithms to form a more complex algorithm.

The category theoretic formulation of dynamical systems is based on lenses, which are a tool for representing certain kinds of compositions. We will focus entirely on lenses in the category of sets and functions.

A lens \(\left(_{A}^{A'}\right) \xrightarrow{(f_g, f_p)} \left(_{B}^{B'}\right)\) in the category of sets and functions \(\mathbf{Set}\) is a pair \((f_g, f_p)\) of morphisms (functions):

\[f_g : A \rightarrow B \qquad f_p : A \times B' \rightarrow A' \\\]
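One way to make this definition concrete is a small Python record holding the two functions. The class name `Lens`, its field names, and the example lens `first` are assumptions introduced here for illustration; nothing below is tied to a particular lens library.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

A = TypeVar("A")    # forward input
A2 = TypeVar("A2")  # backward output, written A' in the text
B = TypeVar("B")    # forward output
B2 = TypeVar("B2")  # backward input, written B' in the text


@dataclass(frozen=True)
class Lens(Generic[A, A2, B, B2]):
    get: Callable[[A], B]        # f_g : A -> B
    put: Callable[[A, B2], A2]   # f_p : A x B' -> A'


# Example: focus on the first component of a pair.
first = Lens(
    get=lambda pair: pair[0],
    put=lambda pair, new_first: (new_first, pair[1]),
)

focused = first.get((1, "x"))      # reads the first component
updated = first.put((1, "x"), 9)   # writes a new first component
```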

Lenses are powerful because many computations can be expressed in terms of the combination of multiple lenses. The simplest way to combine lenses is to stack them in parallel. Given the lenses:

\[\left(_{A}^{A'}\right) \xrightarrow{(f_g, f_p)} \left(_{B}^{B'}\right) \\ \left(_{C}^{C'}\right) \xrightarrow{(g_g, g_p)} \left(_{D}^{D'}\right)\]

we define their monoidal product to be the lens:

\[\left(_{A \otimes C}^{A' \otimes C'} \right) \xrightarrow{(h_g, h_p)} \left(_{B \otimes D}^{B' \otimes D'}\right) \\ h_g = f_g \otimes g_g \\ h_p = \langle f_p \circ (\pi_0 \otimes \pi_0), g_p \circ (\pi_1 \otimes \pi_1) \rangle\]
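In code, the monoidal product simply runs the two lenses side by side on paired values. The sketch below represents a lens as a plain `(get, put)` tuple of Python functions; that representation, the name `parallel`, and the two toy lenses are assumptions made for the example.

```python
def parallel(f, g):
    """Monoidal product of two lenses given as (get, put) pairs."""
    f_get, f_put = f
    g_get, g_put = g

    def get(state):
        a, c = state
        return (f_get(a), g_get(c))

    def put(state, feedback):
        a, c = state
        b2, d2 = feedback
        # Each lens only sees its own component of the state and feedback.
        return (f_put(a, b2), g_put(c, d2))

    return (get, put)


# Two toy lenses on numbers.
scale = (lambda a: 2 * a, lambda a, b2: a + b2)
flip = (lambda c: -c, lambda c, d2: c * d2)

both_get, both_put = parallel(scale, flip)
forward = both_get((3, 4))
backward = both_put((3, 4), (10, 5))
```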

We can also compose lenses directly. Given the lenses:

\[\left(_{A}^{A'}\right) \xrightarrow{(f_g, f_p)} \left(_{B}^{B'}\right) \\ \left(_{B}^{B'}\right) \xrightarrow{(g_g, g_p)} \left(_{C}^{C'}\right)\]

we define their composition to be the lens:

\[\left(_{A}^{A'}\right) \xrightarrow{(h_g, h_p)} \left(_{C}^{C'}\right) \\ h_g = g_g \circ f_g \\ h_p = f_p \circ \langle \pi_0, g_p \circ \langle f_g \circ \pi_0, \pi_1 \rangle \rangle\]
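Written pointwise, the composite sends \(a\) to \(g_g(f_g(a))\) in the forward direction and \((a, c')\) to \(f_p(a, g_p(f_g(a), c'))\) in the backward direction. A short Python sketch, again assuming lenses are represented as `(get, put)` tuples and using the hypothetical name `compose`:

```python
def compose(f, g):
    """Compose f : (A, A') -> (B, B') with g : (B, B') -> (C, C')."""
    f_get, f_put = f
    g_get, g_put = g

    def get(a):
        # Forward pass: f first, then g.
        return g_get(f_get(a))

    def put(a, c2):
        # Backward pass: g's put runs first, on the forward value f_get(a),
        # and its result is handed to f's put.
        return f_put(a, g_put(f_get(a), c2))

    return (get, put)


# Two toy lenses on numbers.
f = (lambda a: a + 1, lambda a, b2: a * b2)
g = (lambda b: 2 * b, lambda b, c2: b + c2)

h_get, h_put = compose(f, g)
forward = h_get(3)        # g_get(f_get(3)) = 2 * (3 + 1)
backward = h_put(3, 10)   # f_put(3, g_put(4, 10)) = 3 * (4 + 10)
```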

We can now characterize dynamical systems as lenses. A discrete system is a lens \(\left(_{S}^{S}\right) \xrightarrow{(f_g, f_p)} \left(_{O}^{I}\right)\) or equivalently, a set of states \(S\), a set of inputs \(I\), a set of outputs \(O\), and two functions:

  • A get (read) function \(f_g: S \rightarrow O\) that generates an output from a state.
  • A put (update) function \(f_p: S \times I \rightarrow S\) that takes a pair of a state and an input and returns a state update.

Intuitively, a discrete system represents the stepwise application of the update function \(f_p\), potentially in response to a sequence of inputs. That is, the discrete system \(\left(_{S}^{S}\right) \xrightarrow{(f_g, f_p)} \left(_{O}^{I}\right)\) describes a dynamical system whose state \(x_{S_{t+1}} \in S\) at time \(t+1\) is given by the equation:

\[x_{S_{t+1}} = x_{S_t} + f_p(x_{S_t}, x_{I_t}) \\\]

where \(x_{I_t} \in I\) is the system input at time \(t\).
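The stepping rule can be sketched as a small driver loop. The function name `run` and the toy system `tracker` are assumptions made for this example; the loop follows the additive rule above, so the put function returns a change to the state rather than the next state itself.

```python
def run(system, state, inputs):
    """Step a discrete system (get, put) through a sequence of inputs."""
    get, put = system
    outputs = [get(state)]
    for x_in in inputs:
        # x_{t+1} = x_t + f_p(x_t, input_t)
        state = state + put(state, x_in)
        outputs.append(get(state))
    return state, outputs


# Toy system: the state moves halfway toward each input it receives.
tracker = (lambda s: s, lambda s, x_in: 0.5 * (x_in - s))

final_state, outputs = run(tracker, 0.0, [4.0, 4.0, 4.0])
```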

Constructing Momentum from Stochastic Gradient Descent

Given a pair of a loss function \(l: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\) and inference function \(f: \mathbb{R}^p \times \mathbb{R}^a \rightarrow \mathbb{R}\) the stochastic gradient descent dynamical system \(sg\) has the following structure:

\[\left(_{\mathbb{R}^{p}}^{\mathbb{R}^{p}}\right) \xrightarrow{(sg_g, sg_p)} \left(_{\mathbb{R}^p}^{\mathbb{R}^a \times \mathbb{R}}\right) \\ sg_g(x_p) = x_p \\ sg_p(x_p, (x_a, y)) = -\nabla l(f(x_p, x_a), y)\]

This system iteratively updates its state \(x_p\) each time a new sample \((x_a, y) \in \mathbb{R}^a \times \mathbb{R}\) is observed.
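As an illustration, the system \(sg\) can be stepped by hand with the rule \(x_{S_{t+1}} = x_{S_t} + f_p(x_{S_t}, x_{I_t})\). The sketch below assumes \(p = a = 1\), the model \(f(x_p, x_a) = x_p x_a\), and squared-error loss, none of which come from the definition above. Note that no step size appears in \(sg_p\), so the effective step size is 1, and the sample is chosen with a small feature value so that this unit step remains stable.

```python
def grad(w, x, y):
    """Gradient in w of the assumed loss (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x


def sg_get(w):
    """Read the state directly."""
    return w


def sg_put(w, sample):
    """State update: the negative gradient at the current state."""
    x, y = sample
    return -grad(w, x, y)


w = 0.0
for sample in [(0.5, 1.0)] * 50:
    # Additive stepping rule: new state = old state + put(old state, input).
    w = w + sg_put(w, sample)

prediction = sg_get(w) * 0.5  # model output for the feature value 0.5
```

With this sample each step halves the distance between the state and the value 2, which is the parameter that reproduces the target exactly.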

We can also define a dynamical system to represent stochastic momentum. Given a pair of a loss function \(l: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\) and inference function \(f: \mathbb{R}^p \times \mathbb{R}^a \rightarrow \mathbb{R}\) the momentum dynamical system \(sm\) has the following structure:

\[\left(_{\mathbb{R}^{p} \times \mathbb{R}^{p}}^{\mathbb{R}^p \times \mathbb{R}^{p}}\right) \xrightarrow{(sm_g, sm_p)} \left(_{\mathbb{R}^p}^{\mathbb{R}^a \times \mathbb{R}}\right) \\ sm_g((x_p, x'_p)) = x_p \\ sm_p((x_p, x'_p), (x_a, y)) = (x'_p, - x'_p -\nabla l(f(x_p, x_a), y))\]

We can construct momentum from the composition and tensor of stochastic gradient descent with some basic lenses. To start, consider the following discrete systems.

  • The discrete system \(add\) reads the state directly and uses the sum of the input values as the state update:
\[\left(_{\mathbb{R}^{p}}^{\mathbb{R}^{p}}\right) \xrightarrow{(add_g, add_p)} \left(_{\mathbb{R}^p}^{\mathbb{R}^p \times \mathbb{R}^p}\right) \\ add_g(x_p) = x_p \\ add_p(x_p, (x'_p,x''_p)) = x'_p+x''_p\]
  • The discrete system \(cp\) reads the state directly and uses the current state as the state update:
\[\left(_{\mathbb{R}^{p}}^{\mathbb{R}^{p}}\right) \xrightarrow{(cp_g, cp_p)} \left(_{\mathbb{R}^p}^{*}\right) \\ cp_g(x_p) = x_p \\ cp_p(x_p, *) = x_p\]
  • The discrete system \(sw\) reads the state directly and swaps the positions of the input values to generate the state update:
\[\left(_{\mathbb{R}^p \times \mathbb{R}^p}^{\mathbb{R}^p \times \mathbb{R}^p}\right) \xrightarrow{(sw_g, sw_p)} \left(_{\mathbb{R}^p \times \mathbb{R}^p}^{\mathbb{R}^p \times \mathbb{R}^p}\right) \\ sw_g(x_p, x'_p) = (x_p, x'_p) \\ sw_p((x_p, x'_p), (x''_p, x'''_p)) = (x'''_p, x''_p)\]

Consider also the following lenses:

  • The lens \(ng\) uses the negated input value as the state update:
\[\left(_{*}^{\mathbb{R}^{p}}\right) \xrightarrow{(ng_g, ng_p)} \left(_{*}^{\mathbb{R}^p}\right) \\ ng_g(*) = * \\ ng_p(*, x_p) = -x_p\]
  • The lens \(srp\) reads the left component of the system state and uses the right component to generate the state update:
\[\left(_{\mathbb{R}^p \times \mathbb{R}^p}^{\mathbb{R}^a \times \mathbb{R} \times \mathbb{R}^{p}}\right) \xrightarrow{(srp_g, srp_p)} \left(_{\mathbb{R}^p}^{\mathbb{R}^a \times \mathbb{R}}\right) \\ srp_g(x_p, x'_p) = x_p \\ srp_p((x_p, x'_p), (x_a, y)) = (x_a, y, x'_p)\]

We can now construct the momentum dynamical system \(sm\) as the following composition:

\[sm = srp \circ (((sg \times ng) \circ add) \times cp) \circ sw \\\]

Let’s break this down into its component parts to see this more clearly. We can draw the composition:

\[\left(_{\mathbb{R}^{p}}^{\mathbb{R}^{p}}\right) \xrightarrow{ (((sg \times ng) \circ add)_g, ((sg \times ng) \circ add)_p)} \left(_{\mathbb{R}^p}^{\mathbb{R}^a \times \mathbb{R} \times \mathbb{R}^{p}}\right) \\\]

where:

\[((sg \times ng) \circ add)_g(x_p) = (sg \times ng)_g(add_g(x_p)) = (sg \times ng)_g(x_p) = x_p \\\]

and:

\[\begin{aligned} ((sg \times ng) \circ add)_p(x_p, (x_a, y, c_p)) &= add_p(x_p, (sg \times ng)_p(add_g(x_p), (x_a, y, c_p))) \\ &= add_p(x_p, (sg_p(add_g(x_p), (x_a, y)), ng_p(*, c_p))) \\ &= add_p(x_p, (sg_p(x_p, (x_a, y)), -c_p)) \\ &= add_p(x_p, (-\nabla l(f(x_p, x_a), y), -c_p)) \\ &= -c_p - \nabla l(f(x_p, x_a), y) \end{aligned}\]

We can build on this to form:

\[\left(_{\mathbb{R}^{p} \times \mathbb{R}^{p}}^ {\mathbb{R}^p \times \mathbb{R}^{p}}\right) \xrightarrow{ ((((sg \times ng) \circ add)\times cp)_g, (((sg \times ng) \circ add)\times cp)_p) } \left(_{\mathbb{R}^p \times \mathbb{R}^p}^ {\mathbb{R}^a \times \mathbb{R} \times \mathbb{R}^p}\right) \\ (((sg \times ng) \circ add)\times cp)_g(x_p, x'_p) = (x_p,x'_p) \\ (((sg \times ng) \circ add)\times cp)_p((x_p, x'_p), (x_a, y, c_p)) = (-c_p - \nabla l(f(x_p, x_a), y), x'_p)\]

We can further build on this to form:

\[\left(_{\mathbb{R}^{p} \times \mathbb{R}^{p}}^ {\mathbb{R}^p \times \mathbb{R}^{p}}\right) \xrightarrow{ (((((sg \times ng) \circ add)\times cp) \circ sw)_g, ((((sg \times ng) \circ add)\times cp) \circ sw)_p) } \left(_{\mathbb{R}^p \times \mathbb{R}^p}^ {\mathbb{R}^a \times \mathbb{R} \times \mathbb{R}^p}\right) \\ ((((sg \times ng) \circ add)\times cp) \circ sw)_g(x_p, x'_p) = (x_p,x'_p) \\ ((((sg \times ng) \circ add)\times cp) \circ sw)_p((x_p, x'_p), (x_a, y, c_p)) = (x'_p, -c_p - \nabla l(f(x_p, x_a), y))\]

Putting it all together we have:

\[\left(_{\mathbb{R}^{p} \times \mathbb{R}^{p}}^ {\mathbb{R}^p \times \mathbb{R}^{p}}\right) \xrightarrow{ ((srp \circ ((((sg \times ng) \circ add)\times cp) \circ sw))_g, (srp \circ ((((sg \times ng) \circ add)\times cp) \circ sw))_p) } \left(_{\mathbb{R}^p}^ {\mathbb{R}^a \times \mathbb{R}}\right) \\ (srp \circ ((((sg \times ng) \circ add)\times cp) \circ sw))_g(x_p, x'_p) = x_p = sm_g((x_p, x'_p)) \\ (srp \circ ((((sg \times ng) \circ add)\times cp) \circ sw))_p((x_p, x'_p), (x_a, y)) = (x'_p, -x'_p - \nabla l(f(x_p, x_a), y)) = sm_p((x_p, x'_p), (x_a, y))\]
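The construction can also be checked numerically. The sketch below is one possible encoding rather than a definitive implementation. It assumes \(p = a = 1\), the model \(f(x_p, x_a) = x_p x_a\) with squared-error loss, lenses represented as `(get, put)` tuples, and `None` standing in for the one-element set \(*\). The text silently identifies \(\mathbb{R}^p \times *\) with \(\mathbb{R}^p\) and flattens nested tuples; the code makes those identifications explicit through two hypothetical helper lenses, `insert_unit` and `remove_unit`, and through the nested shape of the value that `srp` emits. The helper `compose` lists lenses in the order they are applied, which is the reverse of the \(\circ\) notation.

```python
from functools import reduce


def compose(*lenses):
    """Compose (get, put) lenses; the first argument is applied first."""
    def compose2(f, g):
        f_get, f_put = f
        g_get, g_put = g
        return (lambda a: g_get(f_get(a)),
                lambda a, c2: f_put(a, g_put(f_get(a), c2)))
    return reduce(compose2, lenses)


def parallel(f, g):
    """Monoidal product: run two lenses side by side on paired values."""
    f_get, f_put = f
    g_get, g_put = g
    return (lambda s: (f_get(s[0]), g_get(s[1])),
            lambda s, i: (f_put(s[0], i[0]), g_put(s[1], i[1])))


def grad(w, x, y):
    """Gradient in w of the assumed loss (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x


UNIT = None  # stands in for the one-element set *

sg = (lambda w: w, lambda w, sample: -grad(w, sample[0], sample[1]))
ng = (lambda u: u, lambda u, c: -c)
add = (lambda w: w, lambda w, pair: pair[0] + pair[1])
cp = (lambda w: w, lambda w, u: w)
sw = (lambda s: s, lambda s, pair: (pair[1], pair[0]))
# srp emits the nested input shape that the product lenses below expect.
srp = (lambda s: s[0], lambda s, sample: ((sample, s[1]), UNIT))

# Explicit versions of the identification of R^p x * with R^p.
insert_unit = (lambda w: (w, UNIT), lambda w, update: update)
remove_unit = (lambda s: s[0], lambda s, inp: inp)

core = compose(add, insert_unit, parallel(sg, ng), remove_unit)
built = compose(sw, parallel(core, cp), srp)

# Direct definition of the momentum system sm, for comparison.
sm = (lambda s: s[0],
      lambda s, sample: (s[1], -s[1] - grad(s[0], sample[0], sample[1])))

state, sample = (0.5, -0.25), (2.0, 3.0)
assert built[0](state) == sm[0](state)
assert built[1](state, sample) == sm[1](state, sample)
```

Under the additive stepping rule, the update \((x'_p, -x'_p - \nabla l)\) appears to correspond to the earlier momentum algorithm with \(\alpha = \beta = 1\), since neither hyperparameter appears in \(sm_p\).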

Conclusions

In this post we explored how we can leverage the composition of dynamical systems to construct complex optimization algorithms from simpler components. In particular, we demonstrated that the momentum optimization algorithm can be constructed from nothing more than stochastic gradient descent and some simple lens operations. It should be simple to extend this strategy to other algorithms like Adagrad and Adam.

Furthermore, we can probably utilize this dynamical systems perspective to reason about the relationship between the optimization process and the data that we feed in from the outside. We can similarly represent the optimization hyperparameters or the configuration of the data generation process as the dynamical system input.