It’s often of interest to compute gradients in function space. If we were in a standard finite-dimensional space like \({\mathbb{R}}^{n}\), this would be quite simple: just use an off-the-shelf automatic differentiation engine (e.g. JAX, PyTorch). But at first glance, these tools are not applicable in function spaces. Fortunately, in certain cases we can easily reuse them to compute functional gradients – as I will show below.
Note: For simplicity, we will only concern ourselves with computing the gradients of functionals here. But things mostly generalize.
A quick primer on functional gradients
Let \(\mathcal{H}\) be a Hilbert space – i.e., a complete inner product vector space. Then we can define the Gateaux derivative – a generalization of the usual directional derivative – as follows:
Definition (Gateaux derivative in \(\mathcal{H}\)). Consider a functional \(L:\mathcal{H} \rightarrow {\mathbb{R}}\). Then the Gateaux derivative of \(L\) at \(h \in \mathcal{H}\) in direction \(d \in \mathcal{H}\) – denoted \(DL(h;d)\) – is given by \[DL(h;d) ≔ \lim\limits_{\delta \rightarrow 0}\frac{L(h + \delta d) - L(h)}{\delta}.\] \(L\) is said to be Gateaux differentiable if the limit exists for all \(h\) and \(d\).
The Gateaux derivative is often also required to be linear in the direction \(d\), in which case we can say that \(L\) is linearly Gateaux differentiable. (Some authors already define Gateaux differentiability to require linearity.)
It is often helpful to rewrite this as \[\lim\limits_{\delta \rightarrow 0}\frac{L(h + \delta d) - L(h)}{\delta} = \left\lbrack \frac{d}{d\delta}L(h + \delta d) \right\rbrack_{\delta = 0},\] the two sides being trivially equivalent whenever the right-hand side exists.
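As a quick numerical sanity check of this rewriting (my own sketch, not from the post), we can discretize a simple functional on a grid and compute its Gateaux derivative by scalar autodiff over \(\delta\). For \(L(h) = \int_{0}^{1} h(x)^{2}\,dx\), we should recover \(DL(h;d) = 2\int_{0}^{1} h(x)d(x)\,dx\):

```python
import jax
import jax.numpy as jnp

# Sketch: approximate L(h) = ∫_0^1 h(x)^2 dx by a midpoint Riemann sum,
# then compute DL(h; d) = [d/dδ L(h + δ d)]_{δ=0} with scalar autodiff.
n = 1000
xs = (jnp.arange(n) + 0.5) / n  # midpoints of a uniform grid on [0, 1]
dx = 1.0 / n

def L(h):
    return jnp.sum(jax.vmap(h)(xs) ** 2) * dx

h, d = jnp.sin, jnp.cos
gateaux = jax.grad(lambda delta: L(lambda x: h(x) + delta * d(x)))(0.0)
# Analytically, DL(h; d) = 2 ∫_0^1 sin(x) cos(x) dx = sin(1)^2 ≈ 0.7081
```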
However, so far all this gives us is the directional derivative; we want a “gradient vector”. To do this we use the Hilbert space structure, by the means of the Riesz representation theorem:
Theorem (Riesz representation theorem). Let \(\mathcal{H}\) be a real Hilbert space. Then every continuous linear functional \(F:\mathcal{H} \rightarrow {\mathbb{R}}\) can be written in the form \[F(h) = \langle r_{F},h\rangle,\] for some \(r_{F} \in \mathcal{H}\).
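To make the theorem concrete, here is a standard example (my own addition): in \(L^{2}(\mu)\), integration against a fixed function \(g \in L^{2}(\mu)\) defines a continuous linear functional whose representer is \(g\) itself, \[F(h) = \int h(x)\,g(x)\,d\mu(x) = \langle g,h\rangle_{L^{2}(\mu)},\qquad r_{F} = g;\] similarly, in \({\mathbb{R}}^{n}\) every linear functional has the form \(F(h) = \mathbf{a}^{\top}h\), with representer \(r_{F} = \mathbf{a}\).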
Essentially, the theorem lets us rewrite any continuous linear functional over a Hilbert space as an inner product of the argument with a fixed vector. Applying it to the Gateaux derivative of a linearly Gateaux differentiable functional (assuming the derivative is also continuous in \(d\)), we obtain
\[DL(h;d) = \left\langle \nabla L(h),d \right\rangle;\]
the suggestively notated \(\nabla L(h) \in \mathcal{H}\) is then termed the gradient of \(L\) at \(h\). This coincides exactly with the usual gradient in \({\mathbb{R}}^{n}\), but is immediately also defined in general Hilbert spaces.
It’s worth highlighting that the gradient \(\nabla L(h)\) fundamentally depends on the Hilbert space \(\mathcal{H}\), and in particular on its inner product \(\langle \cdot , \cdot \rangle\). Indeed, it is very much possible to equip the same set \(\mathcal{H}\) with two different inner products, each leading to a different gradient: a simple example would be \(L^{2}\) space with different (but equivalent) measures. Therefore, the gradient must always be stated in relation to a particular Hilbert space and its inner product.
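Here is a small illustration of this dependence (my own sketch): in \({\mathbb{R}}^{2}\), replacing the Euclidean inner product with \(\langle u,v\rangle_{M} = u^{\top}Mv\) for a positive definite \(M\) changes the gradient from \(g\) to \(M^{-1}g\), since \(DL(h;d) = g^{\top}d = \langle M^{-1}g,d\rangle_{M}\):

```python
import jax
import jax.numpy as jnp

# The same functional, two inner products on R^2.
L = lambda h: jnp.sum(h ** 2)
h = jnp.array([1.0, 2.0])

g = jax.grad(L)(h)                   # Euclidean gradient: 2h = [2, 4]
M = jnp.diag(jnp.array([1.0, 4.0]))  # a different (weighted) inner product
g_M = jnp.linalg.solve(M, g)         # gradient w.r.t. <u, v>_M = uᵀ M v

# Both satisfy DL(h; d) = <grad, d> in their respective inner products,
# yet g = [2, 4] while g_M = [2, 1]: the gradient depends on the geometry.
```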
Reproducing Kernel Hilbert Spaces (RKHSs)
Among all Hilbert spaces, there are some that satisfy a very handy additional structure: reproducing kernels. For simplicity, let us consider just the case of Hilbert spaces of functions from some set \(\mathcal{X}\) to the reals \(\mathbb{R}\).
Definition (Reproducing Kernel Hilbert Space). A Hilbert space \(\mathcal{H}\) of functions \(\mathcal{X} \rightarrow {\mathbb{R}}\) is a Reproducing Kernel Hilbert Space (RKHS) if for every \(x \in \mathcal{X}\) there exists some \(K_{x} \in \mathcal{H}\) (the reproducing kernel) such that \[\langle h,K_{x}\rangle = h(x)\quad\text{ for all }h \in \mathcal{H}.\]
I.e., RKHSs are function spaces in which evaluation at a point can be reframed as an inner product. It’s worth noting that RKHSs admit many equivalent definitions; see e.g. the definition on Wikipedia.
Example (\(L^{2}\) is not an RKHS). Perhaps the most canonical function space with Hilbert structure is \(L^{2}\) with Lebesgue measure. However, it is “too rich” to be an RKHS: see, e.g., this StackExchange post.
Example (\(H^{1}\) Sobolev space is an RKHS). While the usual \(L^{2}\) space is not an RKHS, many Sobolev spaces (which can generally be thought of as the canonical function spaces in which derivatives are defined) are actually RKHSs. For example, the \(H^{1}\) space of (weakly) differentiable functions from \(\mathbb{R}\) to \(\mathbb{R}\) is an RKHS.
Example (\({\mathbb{R}}^{n}\) is an RKHS). The usual \({\mathbb{R}}^{n}\) space, with elements viewed as functions from \(\left\{ 1,\ldots,n \right\}\) to \(\mathbb{R}\), is an RKHS: for each \(x \in \left\{ 1,\ldots,n \right\}\), simply take the canonical basis vector \(\mathbf{e}_{x}\), which is zero in all entries other than \(x\), where it is one.
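A one-line numerical check of this reproducing property (my own addition):

```python
import jax.numpy as jnp

# In R^n, evaluation at coordinate x equals the inner product with e_x.
n = 5
h = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])  # a "function" {1,...,n} -> R
x = 2                                      # zero-based index of the 3rd entry
e_x = jnp.zeros(n).at[x].set(1.0)          # canonical basis vector, i.e. K_x
value = jnp.dot(h, e_x)                    # <h, K_x> = h(x) = 4.0
```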
A simple autodiff in an RKHS
We are now ready to combine all our ingredients. Ideally, we would have some way of doing autodiff to compute the full functional gradient \(\nabla L(h)\). However, this is generally an infinite-dimensional object (e.g. in function spaces), and thus cannot be fully stored in memory. So let us instead content ourselves with knowing the functional gradient at a finite set of points \(\left\{ x_{1},\ldots,x_{n} \right\} \subset \mathcal{X}\).
Because we are in an RKHS, we can write, for each \(i = 1,\ldots,n\), \[\nabla L(h)\left( x_{i} \right) = \langle\nabla L(h),K_{x_{i}}\rangle = DL\left( h;K_{x_{i}} \right),\] the latter of which can be expanded as \[DL\left( h;K_{x_{i}} \right) = \lim\limits_{\delta \rightarrow 0}\frac{L\left( h + \delta K_{x_{i}} \right) - L(h)}{\delta} = \left\lbrack \frac{d}{d\delta}L\left( h + \delta K_{x_{i}} \right) \right\rbrack_{\delta = 0}.\] Hence, \[\nabla L(h)\left( x_{i} \right) = \left\lbrack \frac{d}{d\delta}L\left( h + \delta K_{x_{i}} \right) \right\rbrack_{\delta = 0}.\]
This is great! We have reduced computing the functional gradient at a
point \(x_{i}\) to the computation of a
single directional derivative, which in turn we reduce to the
computation of a single scalar derivative. This can trivially
be done with any existing scalar autodiff implementation. Here is a
basic implementation in JAX, mimicking the API of
jax.grad:
import jax
import jax.numpy as jnp

# functional : H -> RR
# kernel : X times X -> RR
# f : H
# x : X
def rkhs_grad(functional, kernel):
    def gradient_operator(f):
        def gradient_output(x):
            return jax.grad(
                lambda delta: functional(lambda x_: f(x_) + delta * kernel(x, x_))
            )(0.0)
        return gradient_output
    return gradient_operator

And we can use it as follows:
import matplotlib.pyplot as plt

xs = jnp.linspace(-4, +4, 200)
ys = jnp.sinc(xs)

def functional(f):
    # MSE loss
    preds = jax.vmap(f)(xs)
    return jnp.mean(0.5 * (preds - ys) ** 2)

def h1_kernel(x, y):
    # kernel for H^1(R); cf. https://hal.science/hal-03836621v1/document
    return 0.5 * jnp.exp(-jnp.abs(x - y))

h = lambda x: 0.0  # just some function to compute the gradient on
grad = rkhs_grad(functional, h1_kernel)(h)
plt.plot(xs, jax.vmap(grad)(xs), label=r"$\nabla L(h)$")

That said, it is worth noting that this implementation is somewhat inefficient, in that it does not reuse the autodiff machinery and intermediate computations across the different \(x_{i}\)s. Improving this is left as an exercise to the reader.