Equivariant Neural Networks
By: Gijs Bellaard
Convolutional Neural Networks
Let's first motivate why we are interested in standard planar convolution/correlation by showing how they naturally arise in the context of machine learning. Suppose I want to make a feed-forward neural network that takes in a photograph of an apple in an environment and returns a segmentation of that apple. It makes sense that if I translate the photograph and give this translated version to the neural network, I want the network to return a translated version of the original segmentation. In other words, we want the network to be equivariant with respect to translations of the input image.
Let's put this idea into more mathematical terms. We model the photograph and the segmentation as real-valued functions on \(\bbR^2\). Let \(F(\bbR^2,\bbR)\) denote this space of functions. The neural network is then a mapping \(\cN : F(\bbR^2,\bbR) \to F(\bbR^2,\bbR)\). We consider a translation \(\bv\) to be an element of the group \((\bbR^2,+)\), which acts naturally on points \(\bx \in \bbR^2\): \begin{equation} \bv \act \bx := \bx + \bv \end{equation} This action extends in the usual way to an action on real-valued functions \(f\) on \(\bbR^2\): \begin{equation} (\bv \act f)(\bx) = f(\bx - \bv) \end{equation} The desired translation equivariance property can now be written as: \begin{equation} \cN( \bv \act f) = \bv \act \cN(f) \end{equation} for every \(f : \bbR^2 \to \bbR\) and \(\bv \in \bbR^2\), or more abstractly: \begin{equation} \label{eq:abstract_equivariance} \cN \circ (\bv \act) = (\bv \act) \circ \cN, \end{equation} where we now consider \(\bv \act : F(\bbR^2,\bbR) \to F(\bbR^2,\bbR) \) also as a mapping from functions to functions. The figure below illustrates the situation.
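On a computer such a photograph is of course sampled on a pixel grid. As a minimal sketch of the translation action in that discrete setting (assuming periodic boundary conditions so that a shift can be implemented with np.roll; the helper name translate is my own):

```python
import numpy as np

def translate(f, v):
    """Discrete analogue of (v . f)(x) = f(x - v): shift the sampled image
    by v = (rows, cols) pixels, assuming periodic boundary conditions."""
    return np.roll(f, shift=v, axis=(0, 1))

f = np.zeros((6, 6))
f[1, 1] = 1.0                 # a single bright pixel at row 1, column 1
print(translate(f, (2, 3)))   # the pixel has moved to row 3, column 4
```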
But how does one make a (possibly very complicated) neural network equivariant? Well, it is easier to consider just the constituent parts of the network. Suppose our feed-forward network can be separated into layers \(L_i : F(\bbR^2,\bbR) \to F(\bbR^2,\bbR)\), that is \(\cN = L_n \circ L_{n-1} \circ \cdots \circ L_1\). If we make every layer \(L_i\) equivariant, then the whole network is also equivariant. This can be quickly understood using \eqref{eq:abstract_equivariance}: \begin{equation} \begin{split} \cN \circ (\bv \act) &= L_n \circ L_{n-1} \circ \cdots \circ L_2 \circ L_1 \circ (\bv \act) \\ &= L_n \circ L_{n-1} \circ \cdots \circ L_2 \circ (\bv \act) \circ L_1 \\ &\hphantom{==}\vdots\\ &= (\bv \act) \circ L_n \circ L_{n-1} \circ \cdots \circ L_2 \circ L_1 \\ &= (\bv \act) \circ \cN \end{split} \end{equation}
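As a quick numerical sanity check of this composition argument (the toy layers below are my own choices, not from the text), one can compose a pointwise nonlinearity with a local averaging layer, both clearly translation equivariant, and verify that the composite is equivariant too; translations are again discretized as periodic shifts:

```python
import numpy as np

L1 = lambda f: np.maximum(f, 0.0)                        # pointwise nonlinearity
L2 = lambda f: (f + np.roll(f, 1) + np.roll(f, -1)) / 3  # local averaging
N = lambda f: L2(L1(f))                                  # the "network" L2 after L1

rng = np.random.default_rng(0)
f, v = rng.normal(size=32), 7
# v . N(f) == N(v . f), where (v . f)[x] = f[x - v] is a periodic shift.
print(np.allclose(np.roll(N(f), v), N(np.roll(f, v))))   # True
```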
Linear layers are ubiquitous in machine learning, so let us consider a linear translation equivariant layer. One easy translation equivariant linear operation is cross-correlation, or what the machine learning community wrongly calls convolution. Cross-correlation \(\star\) is defined as: \begin{equation} \label{eq:cross_correlation_R2} (k \star f)(\bx) = \int_{\bbR^2} k(\by - \bx)\ f(\by)\ d\by \end{equation} notice that this subtly differs from convolution: \begin{equation} (k * f)(\bx) = \int_{\bbR^2} k(\bx - \by)\ f(\by)\ d\by \end{equation} The interpretation of cross-correlation is straightforward: we have a filter/kernel \(k\) that we move over the image \(f\) and calculate the response. Convolution differs from cross-correlation in that the filter/kernel is first reflected through the origin, i.e. \(k(\bx) \mapsto k(-\bx)\), and then moved over the image. What most machine learning libraries call convolution is actually implemented as cross-correlation! In the figure below one can see how correlation (in the one-dimensional case \(\bbR\)) with a Gaussian kernel results in a smoothing of the input function.
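To make the distinction concrete, here is a small and deliberately naive discrete sketch of both operations on one-dimensional signals with periodic boundaries (the helper names and test signals are my own; real libraries use far faster implementations):

```python
import numpy as np

def cross_correlate(k, f):
    """(k ⋆ f)[x] = sum_y k[y - x] f[y], with periodic indexing."""
    n = len(f)
    return np.array([sum(k[(y - x) % n] * f[y] for y in range(n))
                     for x in range(n)])

def convolve(k, f):
    """(k * f)[x] = sum_y k[x - y] f[y], with periodic indexing."""
    n = len(f)
    return np.array([sum(k[(x - y) % n] * f[y] for y in range(n))
                     for x in range(n)])

k = np.array([0.0, 1.0, 2.0, 0.0, 0.0])
f = np.array([0.0, 0.0, 1.0, 0.0, 0.0])   # a discrete delta

print(cross_correlate(k, f))              # [2. 1. 0. 0. 0.]
print(convolve(k, f))                     # [0. 0. 0. 1. 2.]  -- not the same!

# Convolution equals cross-correlation with the reflected kernel k(-x):
k_reflected = np.roll(k[::-1], 1)         # entry i holds k[-i mod n]
print(np.allclose(convolve(k, f), cross_correlate(k_reflected, f)))  # True
```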
Intuitively it makes sense that cross-correlation is translation equivariant: the same filter is applied everywhere, so how could it not be? Let's check mathematically that cross-correlation really is translation equivariant, i.e. that \begin{equation} (\bv \act) \circ (k \star) = (k \star) \circ (\bv \act) \end{equation} Let's apply both sides to some dummy function \(f\) and evaluate it at the dummy position \(\bx\): \begin{equation} \begin{split} (\bv \act (k \star f))(\bx) &= (k \star f)(\bx - \bv) \\ &= \int_{\bbR^2} k(\by - (\bx - \bv))\ f(\by)\ d\by \\ &= \int_{\bbR^2} k(\by' - \bx)\ f(\by' - \bv)\ d\by' \\ &= \int_{\bbR^2} k(\by' - \bx)\ (\bv \act f)(\by')\ d\by' \\ &= (k \star (\bv \act f))(\bx) \end{split} \end{equation} where we applied the substitution \(\by = \by' - \bv\). Indeed, everything works out as expected: cross-correlation is translation equivariant.
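The same check can be carried out numerically; here is a brute-force verification on periodic one-dimensional signals (the discretization and helper names are my own):

```python
import numpy as np

def cross_correlate(k, f):
    """(k ⋆ f)[x] = sum_y k[y - x] f[y], periodic boundaries."""
    n = len(f)
    return np.array([sum(k[(y - x) % n] * f[y] for y in range(n))
                     for x in range(n)])

def translate(f, v):
    """(v . f)[x] = f[x - v], periodic boundaries."""
    return np.roll(f, v)

rng = np.random.default_rng(0)
k, f, v = rng.normal(size=16), rng.normal(size=16), 5

lhs = translate(cross_correlate(k, f), v)   # v . (k ⋆ f)
rhs = cross_correlate(k, translate(f, v))   # k ⋆ (v . f)
print(np.allclose(lhs, rhs))                # True: translation equivariant
```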
So, if we want to make a neural network that is translation equivariant, one easy way to do this is to build it from layers that apply a cross-correlation \(\star\) with some kernel \(k\), where the kernel can be chosen freely. The parameters that are learned during training are then the parameters that determine these kernels.
Roto-translation Equivariance
Let us return to our segmentation of a photograph of an apple. We said that we want the neural network to be equivariant to translations. Makes sense, but there is an additional set of symmetries we might also desire: rotations. So actually we want the network to be equivariant to what we call the roto-translation group.
Let \(SE(2) = \bbR^2 \rtimes SO(2) \subset \bbR^2 \times \bbR^{2 \times 2}\) denote the roto-translation group. An element \(g = (\bv, \bR) \in SE(2)\) of this group consists of a translation \(\bv\) and a rotation matrix \(\bR\). It acts on \(\bbR^2\) in the standard way: \begin{equation} g \act \bx = (\bv, \bR) \act \bx = \bR \bx + \bv \end{equation} The action of this group on \(\bbR^2\) extends straightforwardly to an action on real-valued functions on \(\bbR^2\): \begin{equation} (g \act f)(\bx) = f(g^{-1} \act \bx) = f( \bR^{-1}(\bx - \bv)) \end{equation}
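To get a feel for this action, here is a rough sketch of how \((g \act f)(\bx) = f(\bR^{-1}(\bx - \bv))\) might be sampled on a pixel grid; centring the coordinates on the image centre, the nearest-neighbour lookup, the zero padding, and the helper name are all my own choices:

```python
import numpy as np

def se2_act_on_image(f, theta, v):
    """Sample (g . f)(x) = f(R^{-1}(x - v)) on a pixel grid whose origin is
    the image centre, using nearest-neighbour lookup and zero padding."""
    H, W = f.shape
    cy, cx = (H - 1) / 2, (W - 1) / 2
    c, s = np.cos(theta), np.sin(theta)
    R_inv = np.array([[c, s], [-s, c]])                  # inverse rotation
    ys, xs = np.mgrid[0:H, 0:W]                          # pixel coordinates
    coords = np.stack([xs - cx - v[0], ys - cy - v[1]])  # x - v
    src = np.tensordot(R_inv, coords, axes=1)            # R^{-1}(x - v)
    sx = np.rint(src[0] + cx).astype(int)
    sy = np.rint(src[1] + cy).astype(int)
    inside = (sx >= 0) & (sx < W) & (sy >= 0) & (sy < H)
    out = np.zeros_like(f)
    out[inside] = f[sy[inside], sx[inside]]
    return out

f = np.zeros((9, 9)); f[4, 2:7] = 1.0                  # a horizontal bar
print(se2_act_on_image(f, np.pi / 2, v=(0.0, 0.0)))    # the bar is now vertical
```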
Now, we already saw that the cross-correlation we defined in \eqref{eq:cross_correlation_R2} is equivariant with respect to translations. Maybe it turns out that it is already equivariant to roto-translations as well? Let us check. We first manipulate \(g \act (k \star f)\) in a similar way as before and perform the substitution \(\by = g^{-1} \act \by'\), which leaves the integral's measure unchanged because rigid motions preserve the Lebesgue measure: \begin{equation} \begin{split} (g \act (k \star f))(\bx) &= (k \star f)(g^{-1} \act \bx)\\ &= \int_{\bbR^2} k(\by - g^{-1} \act \bx)\ f(\by)\ d\by \\ &= \int_{\bbR^2} k(g^{-1} \act \by' - g^{-1} \act \bx)\ f(g^{-1} \act \by')\ d\by'\\ &= \int_{\bbR^2} k(\bR^{-1}(\by' - \bx))\ f(g^{-1} \act \by')\ d\by' \end{split} \end{equation} We are almost there: if we assume that \(k(\bR^{-1} \bx) = k(\bx)\) for every rotation \(\bR\), i.e. the kernel is rotationally symmetric, then \begin{equation} \begin{split} &= \int_{\bbR^2} k(\by' - \bx)\ f(g^{-1} \act \by')\ d\by' \\ &= \int_{\bbR^2} k(\by' - \bx)\ (g \act f)(\by')\ d\by' \\ &= (k \star (g \act f))(\bx) \end{split} \end{equation}
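This calculation can also be checked numerically. The sketch below approximates the integrals by Riemann sums on a grid and verifies that correlation with an isotropic (Gaussian) kernel commutes with a roto-translation; the concrete test functions, grid, and sample point are my own choices:

```python
import numpy as np

# Numerical check, by Riemann sums on a grid, that cross-correlation with an
# isotropic kernel commutes with a roto-translation g = (v, R).

t = np.linspace(-6, 6, 121)
dx = t[1] - t[0]
Y1, Y2 = np.meshgrid(t, t, indexing="ij")                 # integration grid for y

k = lambda y1, y2: np.exp(-(y1**2 + y2**2))               # isotropic kernel
f = lambda y1, y2: np.exp(-((y1 - 1)**2 + y2**2)) * y1    # some test "image"

theta, v = 0.7, np.array([0.3, -0.2])
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s], [s, c]])

def g_act(func):
    """(g . f)(x) = f(R^{-1}(x - v)), with R^{-1} = R^T."""
    def acted(x1, x2):
        u1 = R[0, 0] * (x1 - v[0]) + R[1, 0] * (x2 - v[1])
        u2 = R[0, 1] * (x1 - v[0]) + R[1, 1] * (x2 - v[1])
        return func(u1, u2)
    return acted

def correlate(kernel, func, x):
    """(k ⋆ f)(x) = ∫ k(y - x) f(y) dy, approximated by a Riemann sum."""
    return np.sum(kernel(Y1 - x[0], Y2 - x[1]) * func(Y1, Y2)) * dx**2

x = np.array([0.5, 1.0])
lhs = correlate(k, f, R.T @ (x - v))          # (g . (k ⋆ f))(x)
rhs = correlate(k, g_act(f), x)               # (k ⋆ (g . f))(x)
print(lhs, rhs)                               # the two numbers agree up to discretization error
```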
So, we can create a roto-translation equivariant operation by performing cross-correlation with a rotationally symmetric, also known as isotropic, kernel. But in practice this constraint is way too restrictive for use in convolutional neural networks! Before we address this problem, let's first consider a slightly different situation.
Spherical CNNs
Suppose we have a real-valued function on the sphere \(S^2 = \{ x \in \bbR^3 \mid \| x \| = 1 \} \subset \bbR^3\), and we want to generalize cross-correlation to \(S^2\) such that it is equivariant to three-dimensional rotations \(SO(3) \subset \bbR^{3 \times 3}\). We can't just directly re-implement what we wrote in \eqref{eq:cross_correlation_R2}: there is no way to make sense of subtracting two points on \(S^2\). We also notice that we made an implicit assumption previously: we imagined that when we moved the kernel \(k\) around, it had some "center", namely the origin of \(\bbR^2\). We can't do the same on \(S^2\): there is no canonical origin. We also need to define integration on \(S^2\), but that is no problem: we use the canonical measure \(\mu\) on \(S^2\) that we borrow from the ambient space \(\bbR^3\).
Before we continue with \(S^2\), let us reconsider what we are actually doing when performing a cross-correlation on \(\bbR^2\). The kernel \(k\) has an implied center, let's call it \(\bx_0\), which when working on \(\bbR^2\) is usually the origin. When moving the filter to another position \(\bx\) we actually mean translating it by some vector \(\bv \in \bbR^2\) such that its center \(\bx_0\) is moved to \(\bx\). In other words, our cross-correlation on \(\bbR^2\) can be equivalently written as: \begin{equation} (k \star f)(\bx) = \int_{\bbR^2} (\bv_{\bx_0 \to \bx} \act k)(\by)\ f(\by)\ d\by \end{equation} where \(\bv_{\bx_0 \to \bx}\) is the unique translation such that \(\bv_{\bx_0 \to \bx} \act \bx_0 = \bx\).
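As a quick verification (this little computation is my own addition), note that \(\bv_{\bx_0 \to \bx} = \bx - \bx_0\), so that \begin{equation} (\bv_{\bx_0 \to \bx} \act k)(\by) = k(\by - \bv_{\bx_0 \to \bx}) = k(\by - (\bx - \bx_0)), \end{equation} which for the usual choice \(\bx_0 = 0\) reduces to \(k(\by - \bx)\), recovering \eqref{eq:cross_correlation_R2}.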
Now this equivalent way of looking at cross-correlation does generalize: we can exploit our desired rotational symmetries to transport the kernel. We presuppose that our kernel has some implied center \(p_0 \in S^2\) and we define: \begin{equation} \label{eq:cross_correlation_s2} (k \star f)(p) = \int_{S^2} (g_{p_0 \to p} \act k)(q)\ f(q)\ d\mu(q) \end{equation} where \(g_{p_0 \to p} \in SO(3)\) is a rotation for which \(g_{p_0 \to p} \act p_0 = p\).
However, there is a problem: there are multiple ways to choose \(g_{p_0 \to p}\). We fix this issue by only considering kernels that do not depend on this choice. That means we should have that \(g_{p_0 \to p} \act k\) is the same for every possible \(g_{p_0 \to p}\) one can choose. This is equivalent to requiring that \begin{equation} \label{eq:symmetry_requirement_kernel} s_{p_0} \act k = k \end{equation} for any \(s_{p_0} \in SO(3)\) such that \(s_{p_0} \act p_0 = p_0\). We say that the kernel must be invariant under the stabilizer subgroup of its center \(p_0\). Or in other words: the kernel must be symmetric about \(p_0\).
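To make \eqref{eq:cross_correlation_s2} together with this constraint concrete: a kernel that is symmetric about \(p_0\) can only depend on the angle to \(p_0\) (a so-called zonal kernel), and since rotations preserve inner products the transported kernel at \(p\) then only depends on \(p \cdot q\). Below is a small numerical sketch along these lines; the kernel profile kappa, the test signal f, and the Monte Carlo quadrature are my own illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)

def uniform_sphere(n):
    """n points uniformly distributed on S^2."""
    q = rng.normal(size=(n, 3))
    return q / np.linalg.norm(q, axis=1, keepdims=True)

kappa = lambda t: np.exp(4.0 * (t - 1.0))       # kernel profile in cos(angle)
f = lambda q: q[:, 2] ** 2 + 0.5 * q[:, 0]      # some signal on the sphere

def spherical_correlate(p, samples):
    """(k ⋆ f)(p) = ∫ (g_{p0->p} . k)(q) f(q) dμ(q), Monte Carlo estimate.
    For a zonal kernel the transported kernel at q is simply kappa(p . q)."""
    integrand = kappa(samples @ p) * f(samples)
    return 4.0 * np.pi * integrand.mean()        # area of S^2 times the mean

samples = uniform_sphere(200_000)
print(spherical_correlate(np.array([0.0, 0.0, 1.0]), samples))
print(spherical_correlate(np.array([0.0, 0.6, 0.8]), samples))
```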
Okay, so we have this definition of cross-correlation on \(S^2\), but is it actually equivariant with respect to rotations? Let us check: \begin{equation} \label{eq:s2_cross_correlation_equivariance} \begin{split} (g \act (k \star f) )(p) &= (k \star f)(g^{-1} \act p)\\ &= \int_{S^2} (g_{p_0 \to (g^{-1} \act p)} \act k)(q)\ f(q)\ d\mu(q)\\ &= \int_{S^2} (g^{-1} \act g_{p_0 \to p} \act k)(q)\ f(q)\ d\mu(q)\\ &= \int_{S^2} (g_{p_0 \to p} \act k)(g \act q)\ f(q)\ d\mu(q)\\ &= \int_{S^2} (g_{p_0 \to p} \act k)(q')\ f(g^{-1} \act q')\ d\mu(q')\\ &= \int_{S^2} (g_{p_0 \to p} \act k)(q')\ (g \act f)(q')\ d\mu(q')\\ &= (k \star (g \act f) )(p) \end{split} \end{equation} In the third step we used that the composition \(g^{-1} g_{p_0 \to p}\) maps \(p_0\) to \(g^{-1} \act p\), so it is itself a valid choice of \(g_{p_0 \to (g^{-1} \act p)}\); by the kernel constraint \eqref{eq:symmetry_requirement_kernel} the transported kernel does not depend on which choice we make. In the fifth step we substituted \(q = g^{-1} \act q'\). So, indeed, it is equivariant!
And we now see another implicit assumption we made earlier: we assumed that the measure \(\mu\) is invariant under the considered group of symmetries; otherwise the substitution step in \eqref{eq:s2_cross_correlation_equivariance} cannot be made that easily. Mathematically, we mean that: \begin{equation} \label{eq:invariant_measure} \mu(S) = \mu(g \act S) \end{equation} for every measurable set \(S \subset S^2\) and group element \(g \in SO(3)\), which has the corollary that: \begin{equation} \int_{S^2} f(p)\ d\mu(p) = \int_{S^2} (g \act f)(p)\ d\mu(p) \end{equation} Luckily we borrowed the measure on \(S^2\) from \(\bbR^3\), which is indeed invariant under rotations.
Group CNNs
Because we already wrote everything down so abstractly, our discussion of Spherical CNNs now easily generalizes to any Lie group \(G\) of desired symmetries and any homogeneous space \(M\). A homogeneous space is, roughly speaking, a manifold on which the group \(G\) acts transitively, which means that any point \(q\) can be reached from any other point \(p\) using a group element \(g\): \(g \act p = q\). We need this property to successfully move the kernel over the whole space. An equivariant cross-correlation on \(M\) is then exactly what we already wrote down in \eqref{eq:cross_correlation_s2}: \begin{equation} \label{eq:group_cross_correlation} (k \star f)(p) = \int_{M} (g_{p_0 \to p} \act k)(q)\ f(q)\ d\mu(q) \end{equation} which works out as long as the measure \(\mu\) is invariant with respect to \(G\) \eqref{eq:invariant_measure}, and the kernel is symmetric about its center \(p_0\) \eqref{eq:symmetry_requirement_kernel}.
Let's return to our previous problem of making an \(SE(2)\) equivariant cross-correlation for functions on \(\bbR^2\). The problem was that the only correlations allowed are those with isotropic kernels. In practice this constraint is too restrictive, so we need some way around it. One way to solve this issue is to simply stop working within the confines of \(\bbR^2\)! We instead choose to lift the function on \(\bbR^2\) to a function on, for example, the higher dimensional space \(SE(2)\) in an equivariant manner, and then perform equivariant cross-correlations there. A benefit of this approach is that the kernel constraint \eqref{eq:symmetry_requirement_kernel} then becomes trivial: the stabilizer subgroup of a group element, for a group acting on itself, contains nothing but the identity element. There is however a significant cost to be paid: increasing the domain to a higher dimensional space comes with a large increase in memory usage in practice.
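One common way to implement such a lift is a so-called lifting correlation: correlate the planar image with rotated copies of a single kernel, one copy per sampled orientation, so that the output lives on positions \(\times\) orientations. The sketch below is my own minimal version; to keep everything exact it only uses the four \(90^\circ\) rotations and periodic boundary conditions:

```python
import numpy as np

def correlate2d(k, f):
    """(k ⋆ f)[x] = sum_d k[d] f[x + d], centred odd-sized kernel, periodic."""
    m = k.shape[0] // 2
    out = np.zeros_like(f)
    for dy in range(-m, m + 1):
        for dx in range(-m, m + 1):
            out += k[dy + m, dx + m] * np.roll(f, (-dy, -dx), axis=(0, 1))
    return out

def lift(k, f):
    """Lifted feature map F[r, x] = (rotate(k, r) ⋆ f)[x], r = 0, 90, 180, 270 deg."""
    return np.stack([correlate2d(np.rot90(k, r), f) for r in range(4)])

rng = np.random.default_rng(0)
f = rng.normal(size=(16, 16))
k = rng.normal(size=(3, 3))
F = lift(k, f)                                  # shape (4, 16, 16)

# Rotating the input rotates every orientation plane and cyclically shifts
# the orientation axis: the lift is equivariant to these rotations.
lhs = lift(k, np.rot90(f))
rhs = np.stack([np.rot90(F[(r - 1) % 4]) for r in range(4)])
print(np.allclose(lhs, rhs))                    # True
```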
Riemannian Manifold CNNs
The main motivation of Manifold CNNs is the attempt to generalize cross-correlation \eqref{eq:group_cross_correlation} even further, to any Riemannian manifold \((M, \cG)\). To keep things tangible I suggest keeping the sphere \(S^2\) in mind as a prototypical example. In general a Riemannian manifold does not have global symmetries, so the first obstacle we encounter is the question of how to move the filter/kernel \(k\) over the manifold \(M\).
The first thing we do on our way to confidently moving the kernel over the manifold is to change its domain. In the examples we saw so far, the kernel was always a function on the space itself, i.e. a function on the manifold. We change this to the tangent space \(T_{p_0}M\) at some implied center \(p_0 \in M\): the kernel becomes a function \(k : T_{p_0}M \to \bbR\).
To move the kernel around we can now use a standard tool from differential geometry: parallel transport. The parallel transport \(\G_{\g, s \to t} : T_{\g(s)}M \to T_{\g(t)}M\) along a curve \(\g\) gives us a way to transport tangent vectors around the manifold. The figure below shows an example of parallel transport of a vector on a two dimensional Riemannian manifold. To move the kernel we extend the definition of parallel transport to functions on tangent spaces: \begin{equation} (\G_{\g, s \to t} k)(v) = k(\G_{\g, t \to s} v) \end{equation} where \(k : T_{\g(s)}M \to \bbR\) is some kernel at \(\g(s)\) and \(\G_{\g, s \to t} k : T_{\g(t)}M \to \bbR\) is the transported kernel defined at \(\g(t)\).
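Below is a sketch of what this looks like on the unit sphere, where the Levi-Civita parallel transport along a great-circle arc from \(p\) to \(q\) is the rotation about the axis \(p \times q\) by the angle between \(p\) and \(q\). Tangent vectors are represented as 3D vectors orthogonal to their base point, and the helper names are my own:

```python
import numpy as np

def rotation(axis, angle):
    """Rodrigues' rotation matrix about a unit axis."""
    K = np.array([[0, -axis[2], axis[1]],
                  [axis[2], 0, -axis[0]],
                  [-axis[1], axis[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)

def transport(p, q):
    """Matrix of parallel transport T_p S^2 -> T_q S^2 along the great-circle arc."""
    axis = np.cross(p, q)
    sin, cos = np.linalg.norm(axis), np.dot(p, q)
    return rotation(axis / sin, np.arctan2(sin, cos))

p0 = np.array([0.0, 0.0, 1.0])                 # kernel centre
q = np.array([1.0, 0.0, 0.0])

k = lambda w: np.exp(-np.dot(w, w)) * w[0]     # a kernel on T_{p0} S^2
transported_k = lambda w: k(transport(q, p0) @ w)   # (Γ_{p0 -> q} k)(w)

w = np.array([0.0, 0.3, 0.2])                  # a tangent vector at q (w . q = 0)
print(transported_k(w))                        # value of the moved kernel at w
```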
But even more problems start popping up. Namely, there is no canonical parallel transport on a general manifold. Luckily we are working with a Riemannian manifold, which does have a canonical parallel transport: the one induced by the Levi-Civita connection. But even then the transport generally depends on the curve \(\g\) one chooses!
So again, we sweep this issue under the rug by demanding that the transported kernel is the same for any curve \(\g\) along which we parallel transport. This requirement is satisfied exactly when the kernel is invariant under parallel transport around loops based at the center \(p_0\). This means we want that: \begin{equation} \G_{\g, s \to t} k = k \end{equation} for any curve \(\g\) that goes from \(\g(s) = p_0\) back to \(\g(t) = p_0\), which is also called a loop. Once we have a kernel that satisfies this requirement we are allowed to write things like \(\G_{p_0 \to p} k\), by which we mean performing the parallel transport along any curve \(\g\) that goes from \(p_0\) to \(p\).
The set of parallel transports \(\G_{\g, s \to t}\) along loops \(\g\) based at \(p_0\) forms a group called the holonomy group \(H_{p_0}\). Using this terminology we can say that the kernel needs to be invariant with respect to the holonomy group \(H_{p_0}\). The holonomy group of an \(n\)-dimensional Riemannian manifold is always a subgroup of \(O(n)\), as the parallel transport of the Levi-Civita connection preserves the Riemannian metric. In the case of \(M = \bbR^2\) the holonomy group consists of only the identity, so the requirement boils down to nothing. In the case of \(M = S^2\) the holonomy group is \(SO(2)\), so the kernel needs to be rotationally symmetric. In the case of a (flat) Möbius strip the holonomy group is generated by the reflection across the band, so the kernel needs to be symmetric across the band.
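As a concrete illustration of the \(S^2\) case (my own, reusing the great-circle transport from the previous sketch): parallel transporting a tangent vector around the geodesic triangle spanned by the three coordinate axes brings it back rotated by \(90^\circ\), the area (spherical excess) of that triangle, so a kernel at the north pole must indeed be invariant under such rotations:

```python
import numpy as np

def rotation(axis, angle):
    K = np.array([[0, -axis[2], axis[1]],
                  [axis[2], 0, -axis[0]],
                  [-axis[1], axis[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)

def transport(p, q):
    """Parallel transport T_p S^2 -> T_q S^2 along the great-circle arc."""
    axis = np.cross(p, q)
    sin, cos = np.linalg.norm(axis), np.dot(p, q)
    return rotation(axis / sin, np.arctan2(sin, cos))

north = np.array([0.0, 0.0, 1.0])
ex = np.array([1.0, 0.0, 0.0])
ey = np.array([0.0, 1.0, 0.0])

v = np.array([1.0, 0.0, 0.0])                       # tangent vector at the pole
loop = transport(ey, north) @ transport(ex, ey) @ transport(north, ex)
print(loop @ v)                                     # ≈ [0, 1, 0]: rotated by 90 degrees
```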
Great, we can now confidently move the kernel about the manifold, but how do we interpret the integral of the cross-correlation now? The kernel is now a function on a tangent space, but the function \(f\) is still a real-valued function on the manifold \(M\) itself... This is where the Riemannian exponential map comes into play. The Riemannian exponential map \(\exp : TM \to M\) is a mapping from the tangent bundle to the manifold itself. Intuitively, \(p = \exp(v)\) is the place you end up when you start walking from the base point of \(v\) in the direction of \(v\) for one unit of time. In the figure below an example is given of the Riemannian exponential map on the sphere \(S^2\).
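On the unit sphere the exponential map has a simple closed form: for \(v \in T_p S^2\) we have \(\exp_p(v) = \cos(\|v\|)\, p + \sin(\|v\|)\, v / \|v\|\). A tiny sketch (my own) of it:

```python
import numpy as np

def exp_map(p, v):
    """Riemannian exponential map of the unit sphere at p, applied to v."""
    t = np.linalg.norm(v)
    if t < 1e-12:
        return p
    return np.cos(t) * p + np.sin(t) * v / t

p = np.array([0.0, 0.0, 1.0])            # north pole
v = np.array([np.pi / 2, 0.0, 0.0])      # tangent vector of length pi/2
print(exp_map(p, v))                     # [1, 0, 0]: a quarter great circle away
```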
We are now finally ready to give our first definition of cross-correlation on a Riemannian manifold: \begin{equation} (k \star f)(p) = \int_{T_{p}M} (\G_{p_0 \to p} k)(v)\ f(\exp(v))\ d\mu(v) \end{equation} where \(\mu\) is the measure on \(T_p M\) induced by the Riemannian metric \(\cG_p\) at \(p\).
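To tie everything together, here is a rough numerical sketch of this cross-correlation on the unit sphere; all discretization choices below (the grid on the tangent plane, the kernel profile, the test signal) are mine. Because the kernel must be invariant under the holonomy group \(SO(2)\), it is a radial function on the tangent plane, and since parallel transport is an isometry the transported kernel \(\G_{p_0 \to p} k\) is simply the same radial profile on \(T_p S^2\):

```python
import numpy as np

def exp_map(p, v):
    """exp_p(v) = cos(|v|) p + sin(|v|) v / |v| on the unit sphere (vectorized)."""
    t = np.linalg.norm(v, axis=-1, keepdims=True)
    safe_t = np.where(t < 1e-12, 1.0, t)
    return np.cos(t) * p + np.sin(t) * v / safe_t

def tangent_basis(p):
    """An orthonormal basis of T_p S^2 (any choice works for a radial kernel)."""
    a = np.array([1.0, 0.0, 0.0]) if abs(p[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    e1 = np.cross(p, a); e1 /= np.linalg.norm(e1)
    e2 = np.cross(p, e1)
    return e1, e2

k = lambda r: np.exp(-(r / 0.3) ** 2)           # radial kernel profile on T_{p0} M
f = lambda q: q[..., 2] ** 2 + 0.5 * q[..., 0]  # a signal on the sphere

def riemannian_correlate(p, n=201, radius=1.0):
    """(k ⋆ f)(p) = ∫_{T_p M} (Γ_{p0->p} k)(v) f(exp_p(v)) dμ(v), Riemann sum."""
    e1, e2 = tangent_basis(p)
    a = np.linspace(-radius, radius, n)
    da = a[1] - a[0]
    A, B = np.meshgrid(a, a, indexing="ij")
    v = A[..., None] * e1 + B[..., None] * e2             # grid in T_p S^2
    vals = k(np.sqrt(A**2 + B**2)) * f(exp_map(p, v))
    return vals.sum() * da**2

print(riemannian_correlate(np.array([0.0, 0.0, 1.0])))
print(riemannian_correlate(np.array([1.0, 0.0, 0.0])))
```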