The brain can be viewed as a computing or information-processing machine, where the relevant information is extracted from the sensory inputs. As in our computers, the computations can be described as a sequence of instructions, i.e. as algorithms. However, the brain is also very different from our computers, and others prefer to view it as a dynamical system with continuous-time dynamics at multiple temporal and spatial scales. How can we combine these two points of view?

In this blog post, I will try to provide some intuition about how the dynamics of a neural network can compute something. As an example I will consider Principal Component Analysis (PCA), a classical method for dimensionality reduction. The idea is to project high-dimensional data onto a smaller subspace in a way that captures as much of the relevant aspects of the data as possible, and we will see more precisely what that means. After reviewing the basics of PCA, we will use a formulation of PCA as an optimization problem to build a continuous-time dynamical system that computes this algorithm. The animation above hints at the results obtained with the dynamical system that we will derive. We will see the importance of having two separate time scales for the dynamics, and we will conclude that the principle used to build this simple example might be fundamental to better understand brain dynamics.

The code for this post can be found here, and don’t hesitate to contact me by email if you have any questions or remarks.

Principal Component Analysis

Let’s consider $m$ data points $\mathbf{x}^{(1)}, ..., \mathbf{x}^{(m)}$ in $\mathbb{R}^n$ drawn independently and identically distributed (i.i.d.) from a distribution $p_{data}$, which is often called the data generating distribution. The PCA method assumes that the relevant directions are those along which the data vary the most, while those along which the data vary the least are considered as noise. It is important to note that this is an assumption; whether it is satisfied depends on what your data represent and what you want to do with them. We can visualize more concretely what a direction of maximal variance is with a small example in $\mathbb{R}^2$. The figure below shows data points sampled from a Gaussian distribution, and we can see that there is a direction along which the data vary the most, depicted by the solid line segment.

The data $\mathbf{x}^{(1)}, ..., \mathbf{x}^{(m)}$ are assumed to be centered (zero mean), and let’s place them in a matrix $X$ such that the $i^{\text{th}}$ row is given by $(\mathbf{x}^{(i)})^T$. The goal of PCA is to find a subspace of dimension $r$ that captures as much variance as possible and to find an uncorrelated basis of this subspace. We will need the notion of covariance matrix, so here is a short remark for those who are not familiar with it.

Remark: The covariance of two random variables $x_1$ and $x_2$ is defined as $\text{cov}(x_1, x_2) = \mathbb{E}[(x_1 - \mathbb{E}[x_1])(x_2 - \mathbb{E}[x_2])]$, which is equal to $\mathbb{E}[x_1 x_2]$ if the variables have been centered beforehand so that they have zero mean. The entry at the $i^{\text{th}}$ row, $j^{\text{th}}$ column of the covariance matrix of a vector of random variables $\mathbf{x} = (x_1, ..., x_n)$ is given by $\text{cov}(x_i, x_j)$. The expectation over a random variable can be approximated by an average over data points sampled i.i.d. from its distribution. Thus, for centered data, $\text{cov}(x_i, x_j) \approx \frac{1}{m}\sum_{k=1}^m x_i^{(k)}x_j^{(k)}$ and the covariance matrix can be approximated by $\frac{1}{m}X^TX$. From now on, I will use this approximation but won’t make the distinction between random variables and data samples explicit.
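To make the remark concrete, here is a minimal NumPy sketch (my own illustration, with an arbitrary $2\times 2$ covariance) checking that $\frac{1}{m}X^TX$ approaches the true covariance for centered samples:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 5000, 2

# Sample m points in R^n from a correlated Gaussian and center them.
true_cov = np.array([[3.0, 1.2],
                     [1.2, 1.0]])
X = rng.multivariate_normal(mean=np.zeros(n), cov=true_cov, size=m)
X = X - X.mean(axis=0)

# Empirical covariance matrix as used above: (1/m) X^T X.
emp_cov = X.T @ X / m
print(np.round(emp_cov, 2))   # close to true_cov for large m
```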

Let’s start by computing the singular value decomposition (SVD) $X = U\Sigma V^T$, where $U$ and $V$ are orthogonal matrices, and use it to define the following change of variable

$$\mathbf{y} = V^T\mathbf{x}$$

This change of variable can be applied to each example $\mathbf{x}^{(i)}$ in $X$ to obtain a matrix $Y = XV$. Then, using the fact that $X^TX = V\Sigma^2 V^T$, the covariance matrix of $\mathbf{y}$ is given by

$$\frac{1}{m}Y^TY = \frac{1}{m}\Sigma^2$$

Since $\Sigma$ is diagonal, this covariance matrix is diagonal. So we have $n$ uncorrelated random variables $y_i$ with respective variances given by the diagonal elements of $\frac{1}{m}\Sigma^2$, i.e. $\text{var}(y_i) = \frac{\sigma_i^2}{m}$. From a geometrical point of view, the variable $y_i$ gives the component of the vector $\mathbf{x}$ projected onto the basis vector $\mathbf{v}_i$, which is the $i^{\text{th}}$ column of $V$. Hence, the direction of maximal variance is given by the vector $\mathbf{v}_1$ associated with the maximal singular value $\sigma_1$ (the singular values are ordered in decreasing order), and the corresponding variance is $\text{var}(y_1) = \frac{\sigma_1^2}{m}$. More generally, the subspace spanned by the first $r$ basis vectors $\mathbf{v}_1, ..., \mathbf{v}_r$ associated with the $r$ largest singular values $\sigma_1, ..., \sigma_r$ is the subspace of dimension $r$ that captures as much variance of the data as possible. Those $r$ basis vectors are called the first $r$ components of the data.
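In code, the whole procedure boils down to one SVD. Here is a small sketch (the function name `pca_svd` is just for illustration), using the conventions above, where the rows of `X` are the centered data points:

```python
import numpy as np

def pca_svd(X, r):
    """First r components of the centered data matrix X, the variance
    captured along each of them, and the projected data."""
    m = X.shape[0]
    # Thin SVD: X = U @ diag(s) @ Vt, singular values in decreasing order.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    components = Vt[:r].T           # columns v_1, ..., v_r
    variances = s[:r] ** 2 / m      # var(y_i) = sigma_i^2 / m
    Y = X @ components              # components y_i of each data point
    return components, variances, Y
```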

Optimization formulation

To build a dynamical system that computes the solution obtained above, it will be useful to express it as the solution of an optimization problem. Above we defined the vector $\mathbf{y}\in\mathbb{R}^n$ as $\mathbf{y}=V^T\mathbf{x}$, which contains the components of the orthogonal projection of the vector $\mathbf{x}$ onto the columns of $V$. As the goal is to compute the $r$-dimensional subspace of maximal variance, we are only interested in computing, for each data point $\mathbf{x}$,

$$\mathbf{z} = W^T \mathbf{x}$$

where $W \in \mathbb{R}^{n\times r}$ is the matrix containing the first $r$ columns of $V$, and $\mathbf{z} \in \mathbb{R}^r$ contains the components of the projection of $\mathbf{x}$ onto the columns of $W$. Now let’s show that this definition of $\mathbf{z}$ and $W$ is the solution of an optimization problem. To this end, we will use the following theorem.

Theorem (Low-rank approximation): The solution of

$$\min_{\hat{X}} ||\hat{X} - X||^2_F \quad \text{s.t. } \ \text{rank}(\hat{X}) \leq r$$

is $\hat{X} = \sigma_1 \mathbf{u}_1\mathbf{v}_1^T + ... + \sigma_r \mathbf{u}_r\mathbf{v}_r^T$, where $X = U\Sigma V^T$ is the SVD of $X$.

Let’s start with arbitrary $\mathbf{z}^{(i)}$ and matrix $W$, and let’s consider the minimization of the average reconstruction error between a data point $\mathbf{x}^{(i)}$ and the vector $W \mathbf{z}^{(i)}$,

$$\min_{\mathbf{z}^{(i)}, W} \ R(\mathbf{z}^{(1)}, ..., \mathbf{z}^{(m)}, W) := \frac{1}{2}\frac{1}{m}\sum_{i=1}^m||W\mathbf{z}^{(i)}-\mathbf{x}^{(i)}||_2^2 = \frac{1}{2}\frac{1}{m} ||\hat{X}-X||^2_F \tag{1}$$

where $||\cdot||_F$ is the Frobenius norm of a matrix, which can be written as the sum of the squared Euclidean norms of its rows. The factor $\frac{1}{2}$ is just there to simplify the expression of the derivative. The matrix $\hat{X}$ contains the vectors $\mathbf{\hat{x}}^{(i)} = W \mathbf{z}^{(i)}$, $i = 1, ..., m$, in its rows, and its rank is indeed less than or equal to $r$. Using the low-rank approximation theorem, the solution is

$$\begin{aligned} \mathbf{\hat{x}}^{(i)} &\equiv z_1^{(i)}\mathbf{w}_1 + ... + z_r^{(i)}\mathbf{w}_r \\ &= \sigma_1 U_{i1} \mathbf{v}_1 + ... + \sigma_r U_{ir} \mathbf{v}_r \end{aligned}$$

where $\equiv$ means that the equality holds by definition. Since $W$ is the same for all data points while it is $\mathbf{z}^{(i)}$ that depends on the input $\mathbf{x}^{(i)}$, we can identify $z^{(i)}_j$ with $\sigma_j U_{ij}$ and the columns of $W$ with the vectors $\mathbf{v}_1, ..., \mathbf{v}_r$. Thus, we have obtained what we wanted for $W$, and actually for $\mathbf{z}$ too, since

$$W^T \mathbf{x}^{(i)} = W^T (U_{i,:}\Sigma V^T)^T = \begin{bmatrix}\sigma_1U_{i1} \\ \vdots \\ \sigma_rU_{ir}\end{bmatrix} \equiv \mathbf{z}^{(i)}$$

Combining this result with the intuition from the previous section: projecting the data points onto the subspace of dimension $r$ that captures as much of the variance as possible also minimizes the average Euclidean norm of the reconstruction error. In this sense $\mathbf{z}$ is the best linear representation of dimension $r$ of the data. This is particularly useful when the data is high-dimensional since, with $r \ll n$, we obtain a low-dimensional representation that can be used, for example, to visualize the data.
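We can check this equivalence numerically. The sketch below (my own, on synthetic data) compares the reconstruction error (1) obtained with the first $r$ components to the error obtained with a random $r$-dimensional subspace:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, r = 1000, 10, 3
X = rng.standard_normal((m, n)) @ rng.standard_normal((n, n))  # correlated data
X = X - X.mean(axis=0)

def recon_error(X, W):
    """Average reconstruction error (1/2)(1/m) sum_i ||W z_i - x_i||^2 with z_i = W^T x_i."""
    Z = X @ W
    return 0.5 * np.mean(np.sum((Z @ W.T - X) ** 2, axis=1))

U, s, Vt = np.linalg.svd(X, full_matrices=False)
W_pca = Vt[:r].T                                        # first r components
W_rand, _ = np.linalg.qr(rng.standard_normal((n, r)))   # random orthonormal columns

print(recon_error(X, W_pca))    # minimal error: (1/(2m)) * sum of the discarded sigma_i^2
print(recon_error(X, W_rand))   # larger error
```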

Dynamical Equations

Now it’s time to derive the dynamical equations of a neural network computing PCA. We will need 3 groups of neurons: the input units for the input data $\mathbf{x}$, the internal units for the representation $\mathbf{z}$, and the output units for the reconstructed data $\mathbf{\hat{x}}$. As we will see, these neurons are connected together with weights given by the matrix $W$. First, we need to specify the dynamics of the input units, which represent the external input of the network. We will consider that a data point is drawn randomly from the data set and presented to the network for a fixed amount of time $\tau$.

Fast dynamics

The output units compute a reconstruction $\mathbf{\hat{x}}$ of the input from the internal state $\mathbf{z}$ of the network, and the transformation is a linear one specified by the weights $W$. A simple dynamical model that achieves this is given by

$$\tau_{fast} \ \dfrac{d\mathbf{\hat{x}}}{dt} = -\mathbf{\hat{x}} + W\mathbf{z}$$

since the equilibrium solution $\frac{d \mathbf{\hat{x}}}{dt}=0$ is $\mathbf{\hat{x}} = W \mathbf{z}$. The time constant $\tau_{fast}$ controls the rate of convergence towards the equilibrium. Next, to obtain the dynamical equations for $\mathbf{z}$ and $W$, we will use the optimization formulation of the previous section together with the following theorem.

Theorem: If a dynamical system can be written as $\frac{d\mathbf{x}}{dt} = - \nabla V(\mathbf{x})$, then

$$\frac{dV}{dt}(\mathbf{x}(t)) = \nabla V(\mathbf{x}(t)) \cdot \frac{d\mathbf{x}}{dt}(t) = -||\nabla V(\mathbf{x}(t))||^2 \leq 0$$

and $\frac{dV}{dt} = 0$ if and only if $\nabla V = 0$.

The function $V$ can only decrease along the trajectories of the system, and the state converges towards a minimum of $V$. This is the continuous-time equivalent of the gradient descent algorithm. If we take $V$ to be the objective that we are trying to minimize to compute the solution of PCA, its gradients give us dynamical equations for $\mathbf{z}$ and $W$.
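As a small illustration of this continuous-time gradient descent, the sketch below integrates $\frac{d\mathbf{x}}{dt} = -\nabla V(\mathbf{x})$ with the forward Euler method for a simple quadratic $V$ (the matrix $A$ and the step size are arbitrary choices of mine):

```python
import numpy as np

A = np.array([[2.0, 0.0],
              [0.0, 0.5]])

def grad_V(x):
    # Gradient of V(x) = 0.5 * x^T A x
    return A @ x

x, dt = np.array([1.0, -1.0]), 0.01
values = []
for _ in range(1000):
    x = x - dt * grad_V(x)          # Euler step of dx/dt = -grad V(x)
    values.append(0.5 * x @ A @ x)

# V decreases along the trajectory and x converges towards the minimum at the origin.
assert all(a >= b for a, b in zip(values, values[1:]))
print(x)                             # close to [0, 0]
```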

Let’s denote by $\mathbf{x}^{(i)}$ the current input. We would like the internal units $\mathbf{z}^{(i)}$ to compute the orthogonal projection of the input onto the columns of $W$, as is the case in PCA. But we have to be careful because here we have a dynamical system where the units and the weights co-evolve, that is, the weights $W$ are dynamically learned while the internal representation $\mathbf{z}^{(i)}$ is dynamically computed. Thus, the columns of $W$ will not necessarily be orthonormal, especially if the weights are initialized randomly, and the projection cannot be computed as we did before. If we take the gradient of the reconstruction error $R$ defined in (1) with respect to $\mathbf{z}^{(i)}$ and set it to zero, we obtain

$$\nabla_{\mathbf{z}^{(i)}}R = \frac{1}{m}W^T(W\mathbf{z}^{(i)} - \mathbf{x}^{(i)}) = 0 \quad \Longrightarrow \quad \mathbf{z}^{(i)} = (W^TW)^{-1}W^T \mathbf{x}^{(i)}$$

which is the correct formula for the orthogonal projection onto non-orthonormal columns (if $W^TW=I$ we recover $\mathbf{z}^{(i)} = W^T\mathbf{x}^{(i)}$). Using the theorem above, the following dynamics for the internal units will decrease the reconstruction error

$$\tau_{fast} \ \frac{d\mathbf{z}}{dt} = W^T(\mathbf{x}-\mathbf{\hat{x}})$$

The time scale of these dynamics should be faster than the time scale of the input units, i.e. $\tau_{fast} \ll \tau$, so that the network has time to compute the projection of the current input before it changes. Note that, in theory, reaching the equilibrium exactly requires an infinite amount of time, but here we are only interested in an approximate solution, so getting close to it is enough.
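Here is a sketch of these fast dynamics in isolation (fixed input $\mathbf{x}$ and fixed, non-orthonormal weights $W$; the step size and $\tau_{fast}$ are illustrative values of mine), checking that the equilibrium is indeed the orthogonal projection derived above:

```python
import numpy as np

rng = np.random.default_rng(2)
n, r = 6, 2
W = rng.standard_normal((n, r))        # fixed weights, columns not orthonormal
x = rng.standard_normal(n)             # fixed input during the presentation

tau_fast, dt = 0.1, 0.005
x_hat, z = np.zeros(n), np.zeros(r)
for _ in range(20000):                 # 100 time units >> tau_fast
    dx_hat = (-x_hat + W @ z) / tau_fast
    dz = W.T @ (x - x_hat) / tau_fast
    x_hat, z = x_hat + dt * dx_hat, z + dt * dz

z_eq = np.linalg.solve(W.T @ W, W.T @ x)    # (W^T W)^{-1} W^T x
print(np.linalg.norm(z - z_eq), np.linalg.norm(x_hat - W @ z_eq))   # both close to zero
```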

Slow dynamics

It remains to determine the dynamics of the weights. Similarly to what we did for the internal units, we can take the derivative of the reconstruction error with respect to $W$,

$$\frac{\partial R}{\partial W_{kj}} = \frac{1}{m}\sum_{i=1}^m(\hat{x}^{(i)}_k - x^{(i)}_k)z^{(i)}_j$$

and use it to obtain a dynamical equation that changes the weights so as to minimize the reconstruction error, thereby approximating PCA. The difference here is that the gradient contains a sum over the data points, which is problematic because only one data point at a time is shown to the network. The trick is simply to replace this sample average by a time average of the gradient evaluated at one randomly sampled data point. This is achieved by setting the time scale of the weight dynamics to a value much larger than the time scale at which the input data $\mathbf{x}$ changes. Thus, we get the following dynamical equation

$$\tau_{slow} \ \frac{dW}{dt} = (\mathbf{x}-\mathbf{\hat{x}})\mathbf{z}^T$$

where $\tau \ll \tau_{slow}$. This trick is actually well known in machine learning, as it is very similar to stochastic gradient descent, where the expectation of the gradient is replaced by the gradient evaluated at one data point sampled randomly from the data set. The slow time scale then corresponds to a small learning rate. The slow dynamics of the weights make them sensitive to the data distribution rather than to the details of individual data points. This means that the weights capture the statistical regularities, the patterns, in the activity of the neurons. The dynamical equation above is an example of a Hebbian learning rule, where the weights change according to the correlation between the activities of the neurons they connect.
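In discrete time, this learning rule becomes a simple outer-product update, where $dt/\tau_{slow}$ plays the role of a small learning rate (a sketch; the helper name is mine):

```python
import numpy as np

def weight_update(W, x, x_hat, z, dt, tau_slow):
    """One forward-Euler step of tau_slow * dW/dt = (x - x_hat) z^T.

    Over many randomly presented inputs, the small factor dt / tau_slow makes
    the weights integrate (time-average) the correlation between the
    reconstruction error and the internal activity, as in Hebbian learning."""
    return W + (dt / tau_slow) * np.outer(x - x_hat, z)
```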

Fast and slow dynamics

To sum up, the dynamical equations of our network are

$$\left\{ \begin{aligned} \tau_{fast} \ \dot{\mathbf{\hat{x}}} &= -\mathbf{\hat{x}} + W\mathbf{z}\\ \tau_{fast} \ \dot{\mathbf{z}} &= W^T(\mathbf{x}-\mathbf{\hat{x}})\\ \tau_{slow} \ \dot{W} &= (\mathbf{x}-\mathbf{\hat{x}})\mathbf{z}^T \end{aligned} \right.$$

and $\mathbf{x}$ is sampled randomly from the data set at intervals of $\tau$, with $\tau_{fast} \ll \tau \ll \tau_{slow}$. The following diagram illustrates the interactions between the different units of the network.

The important insight to remember from the derivation of these equations is that the dynamics of the internal state and the dynamics of the weights are governed by the same objective function; the difference resides in their respective time scales.
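Before looking at the simulations, here is a minimal sketch of how the full system can be integrated with the forward Euler method on toy Gaussian data (the step size and the values of $\tau_{fast}$, $\tau$, $\tau_{slow}$ below are illustrative choices of mine and may differ from those used for the animations):

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy centered Gaussian data in R^2, one internal unit (r = 1).
m, n, r = 500, 2, 1
X = rng.multivariate_normal(np.zeros(n), [[3.0, 1.2], [1.2, 1.0]], size=m)
X = X - X.mean(axis=0)

dt, tau_fast, tau, tau_slow = 0.01, 0.1, 5.0, 500.0   # tau_fast << tau << tau_slow
steps_per_input = int(tau / dt)

x_hat, z = np.zeros(n), np.zeros(r)
W, _ = np.linalg.qr(rng.standard_normal((n, r)))      # random initial direction (unit norm)

for _ in range(400):                       # 400 presentations of duration tau each
    x = X[rng.integers(m)]                 # new randomly sampled input every tau time units
    for _ in range(steps_per_input):
        dx_hat = (-x_hat + W @ z) / tau_fast
        dz = W.T @ (x - x_hat) / tau_fast
        dW = np.outer(x - x_hat, z) / tau_slow
        x_hat, z, W = x_hat + dt * dx_hat, z + dt * dz, W + dt * dW

# Compare the learned direction (up to sign and scale) with the first component v_1.
_, _, Vt = np.linalg.svd(X, full_matrices=False)
w1 = W[:, 0] / np.linalg.norm(W[:, 0])
print(abs(w1 @ Vt[0]))                     # close to 1 once the weights have converged
```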

Simulations

Gaussian data

Let’s come back to the animation at the beginning of this post, which is a visualization of our dynamical system computing PCA for Gaussian data in $\mathbb{R}^2$. The blue dots represent the whole data set, while the black dot corresponds to the input data currently presented to the network. The cross shows the reconstructed data obtained from the projection of $\mathbf{x}$ onto the vector $W$. The two vectors $\mathbf{v}_1$ and $\mathbf{v}_2$ in gray simply indicate the first two components of the data. Note that I have not given any theoretical results on how to choose the time scales $\tau_{fast}$, $\tau$, and $\tau_{slow}$. What we can do is inspect the dynamics of the system visually and see whether it has the expected behaviour. For example, on the left of the figure below, it seems that the slow time scale $\tau_{slow}$ is too small, as individual data points have a rather large influence on the weights. The example on the right looks better.

Fast and slow dynamics approximating PCA on Gaussian data

MNIST

Now, let’s see how it performs on images of digits. We will consider images of zeros and ones taken from the MNIST dataset. Each image has 28x28 pixels, so the input vector $\mathbf{x}$ lies in a space of dimension 784, which is much larger than the 2 dimensions of the example above. It is not possible to visualize directly what is happening in this high-dimensional space; instead, what we can do is project the whole space onto the first two components $\mathbf{v}_1$, $\mathbf{v}_2$ of the data set, computed with the SVD of $X$. This gives us a window through which we can look at the high-dimensional dynamics, and it allows us to see how close the columns $\mathbf{w}_1$ and $\mathbf{w}_2$ of $W$ are to $\mathbf{v}_1$ and $\mathbf{v}_2$. In the plots below, the vectors $\mathbf{v}_1$, $\mathbf{v}_2$ are shown in gray, and the projections of $\mathbf{w}_1$, $\mathbf{w}_2$ onto the subspace spanned by these two vectors are shown in orange. If a projected orange segment has exactly the same length as the gray segments, it means that the corresponding vector $\mathbf{w}$ belongs to the subspace spanned by $\mathbf{v}_1$, $\mathbf{v}_2$, and the longer it is, the more aligned it is with this subspace.
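For reference, such a projection can be computed as in the following sketch (the helper name `project_weights` is mine; it assumes `W` holds the learned weights and `X` the centered data matrix, as above):

```python
import numpy as np

def project_weights(W, X, k=2):
    """Coordinates of each column of W in the plane spanned by the first k
    components v_1, ..., v_k of the centered data matrix X."""
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    V_k = Vt[:k].T                            # columns v_1, ..., v_k
    coords = V_k.T @ W                        # shape (k, r): coordinates of w_1, ..., w_r
    # The norm of a projected column equals the norm of the column itself
    # exactly when that column lies in span(v_1, ..., v_k).
    return coords, np.linalg.norm(coords, axis=0)
```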

The animation on the left shows the first 500 steps of a simulation of 4000 iterations, and the one on the right shows the last 500 steps. We can observe that at the beginning, the columns of $W$, which are initialized randomly, are not at all aligned with the first two components, so the internal representation $\mathbf{z}$ is not informative at all. As the network sees more and more images of digits, the weights learn to capture the statistical regularities in the data, and at the end, the internal representation extracts meaningful information that allows us to distinguish the two digits. We can observe this in the animation on the right, where the cross is correctly on the right for images of ones and on the left for images of zeros. However, it seems that the weights have more difficulty capturing the second component $\mathbf{v}_2$ of the data.

Fast and slow dynamics approximating PCA on MNIST

Conclusion

The goal of this post was not to provide rigorous results about convergence, uniqueness, etc., but to provide some intuition on how the dynamics of a neural network can approximately compute an algorithm. Even though the computation that our dynamical system realizes is far simpler than the computations carried out by the brain, the principle behind it might be fundamental to understanding brain dynamics. It is the idea that the multi-scale dynamics optimizes a single objective function, which is some kind of measure of the prediction error between inputs and predictions computed from internal representations. This is the idea of predictive coding or, more generally, of the free energy principle, where action, perception and learning are governed by the same objective function. In this framework, the fast dynamics of the internal units of our network is called perception, and the slow dynamics of the weights, learning.