Introduction

In generative models such as VAEs and diffusion models, we would like to learn the parameters of the distribution of some data in order to generate novel examples of that data.
As such, our architecture needs to be stochastic (random) in nature so that we can sample from the learned distribution to generate new data. This means that somewhere between our layers there will be layer(s) that produce random values, and backpropagating through those random values doesn't make sense, nor is it feasible.

This poses a big problem for us when trying to optimize such networks: How do you optimize over layer(s) that are stochastic in nature?
One of the solutions comes in the form of a neat trick, aptly named the Reparametrization Trick.

To better understand the theoretical basis for the reparametrization trick, let’s look at the case of VAEs.
The loss function for the VAE, as described in a previous post, is the $ \mathrm{ELBO} $: $$ \mathcal{L}_{\theta, \phi} = \mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] - \mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{q_\phi(z\mid x)}{p_\theta(z)}\right]$$

with $ p_\theta(z) $ being the prior probability, which is usually assumed to be a standard Gaussian $\mathcal{N}(z; 0, 1)$. One of the main reasons for this assumption is the ease of working with the standard normal distribution, even though it may not always be accurate.
The term $p_\theta(x\mid z)$ is the reconstruction term of the model (how well can the decoder part of the model predict the input $x$ given the latent variable $z$), while $q_\phi(z\mid x)$ is the encoder part of the model (what is the likelihood of $z$ being our latent variable given our input $x$).

Formalizing the Loss

In the above formulation, the encoder part $q_\phi(z \mid x)$ is deterministic: we feed in the data $x$ and get back parameters describing the distribution of $z$. The decoder part $p_\theta(x\mid z)$, however, is non-deterministic, since we wish to sample $z$ from that distribution in order to create novel samples $x$.
Expanding the expectation and simplifying the loss further we get: $$\begin{aligned} \mathcal{L}_{\theta, \phi} &= \int q_\phi(z\mid x)\log p_\theta(x\mid z)~dz - \int q_\phi(z\mid x)\log\frac{q_\phi(z\mid x)}{p_\theta(z)} ~dz\\ &= \int q_\phi(z\mid x)\left[\log p_\theta(x\mid z) - \log\frac{q_\phi(z\mid x)}{p_\theta(z)}\right]~dz\\ &= \mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z) - \log\frac{q_\phi(z\mid x)}{p_\theta(z)}\right]\\ &= \mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z) - \log q_\phi(z\mid x) + \log p_\theta(z)\right]\\ &= \mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ \end{aligned}$$ Given that our architecture has an encoder with parameters $\phi$ and a decoder with parameters $\theta$, we differentiate with respect to both sets of parameters: $$\begin{align} \nabla_\theta\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\theta\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ \nabla_\phi\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\phi\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ \end{align}$$

Knowing the encoder to be deterministic, finding the gradient w.r.t. the parameters $\theta$ gives us: $$\begin{aligned} \nabla_\theta\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\theta\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ &= \mathbb{E}_{q_\phi(z\mid x)}\left[\nabla_\theta\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\right]\\ &= \mathbb{E}_{q_\phi(z\mid x)}\left[\nabla_\theta\log p_\theta(x, z)\right]\\ \end{aligned}$$ This means the derivative of the integral (expectation) here becomes the integral of the derivative. This is allowed because, intuitively, the distribution we integrate over depends on $\phi$ and not on $\theta$, so it remains constant as we differentiate w.r.t. $\theta$. There is therefore no interaction between the derivative and the integral, and, much like $\displaystyle\lim_{x \to \infty}\left[f(x) \pm g(x)\right] = \lim_{x \to \infty}f(x) \pm \lim_{x \to \infty}g(x)$, we can swap the integral and the derivative here. For a more rigorous justification, have a look at the Leibniz integral rule [1].

An important note here is that in practical implementations, the expectation is approximated using a Monte Carlo sampling process. This is just a fancy way of saying that we will sample $N$ times to approximate the expected value: $$ \nabla_\theta\mathcal{L}_{\theta,\phi}(\mathrm{x}) \approx \frac{1}{N}\sum_{i=1}^N\nabla_\theta\log p_\theta(x, z_i), \quad z_i \sim q_\phi(z\mid x) $$
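To make this concrete, here is a minimal NumPy sketch of a Monte Carlo estimate. The Gaussian $q$, the integrand `f`, and all parameter values are made-up stand-ins for illustration, not part of a real VAE:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical stand-ins: q(z|x) is a Gaussian, and f plays the role
# of the integrand (e.g. log p_theta(x, z)).
mu, sigma = 1.0, 0.5

def f(z):
    return np.log1p(z ** 2)

# Monte Carlo: draw N samples from q(z|x) and average f over them.
N = 100_000
z = rng.normal(mu, sigma, size=N)
estimate = f(z).mean()  # approximates E_{q(z|x)}[f(z)]
print(f"MC estimate of E[f(z)]: {estimate:.4f}")
```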

Having found our gradient w.r.t. $\theta$, let's turn our attention to equation $(2)$.
Unfortunately, here we encounter two problems:

  • Since the distribution we take the expectation over, $q_\phi(z\mid x)$, itself depends on $\phi$, the derivative and the integral interact and we cannot interchange them as we did before; therefore, we have to find another way.
  • Our derivative this time around contains a term that is non-deterministic, and we do not know how to differentiate through a sampling operation.

Change of Variable

A clue on how to go about solving the two problems is given in a technique called Change of variable [2].
In its fundamental form, change of variable says that we can describe a function/variable in terms of another function/variable. For example, if we have a function $$f(x) = 2x^2 + x$$ we can rewrite it as $$f(x) = g(x, 2) + g(x, 1)$$ $$\mathrm{where}~g(x, y) = y\times x^y$$ This allows us to solve the original function $f(x)$ indirectly by solving $g(x, y)$ first and then transforming the values back to the $f(x)$ formulation.
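As a tiny sanity check, here is that exact example in Python (purely illustrative):

```python
def f(x):
    return 2 * x ** 2 + x

def g(x, y):
    return y * x ** y

# f evaluated indirectly through g gives the same result:
x = 3.0
assert f(x) == g(x, 2) + g(x, 1)  # both equal 21.0
```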

Since a probability density function is, well, a function, we can use the change of variable to make our job easier. In probability lingo, this would mean that we can describe our original PDF $p(\mathrm{x})$ (which is most likely a complex distribution) using a different PDF $\rho(\mathrm{\epsilon})$ (hopefully, one that’s easier to sample from) and then we find a way to transform the sampled values $\mathrm{\epsilon}$ back into the original distribution $p(\mathrm{x})$.

Furthermore, if we assume that our original distribution $p(\mathrm{x})$ is normal (which is the case with VAEs) with parameters $\mu$ and $\sigma$, then we can have our base distribution $\rho(\mathrm{\epsilon})$ be a standard normal, and our transformation $\rho(\mathrm{\epsilon}) \mapsto p(\mathrm{x})$ be: $$ g(\epsilon; \mu, \sigma) = \mu + \sigma\cdot\epsilon $$ This form of transformation is quite common in parametrized distributions; for more detail have a look through [3]. More importantly, notice that the transformation function is deterministic.
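As a quick numerical check, here is a small NumPy sketch of this transformation (the values of $\mu$ and $\sigma$ are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(seed=0)

mu, sigma = 2.0, 0.7                 # arbitrary target parameters
eps = rng.standard_normal(100_000)   # base distribution rho: N(0, 1)
z = mu + sigma * eps                 # deterministic transform g(eps) = mu + sigma * eps

# The transformed samples follow N(mu, sigma^2):
print(z.mean(), z.std())             # ~2.0 and ~0.7
```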

Statistical Magic

Using change of variable, we can now redefine our gradient $(2)$ as: $$\begin{aligned} \nabla_\phi\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\phi\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ \end{aligned}$$ $$ \mathrm{where}~z = g(\phi, \mathrm{x}, \epsilon)~\mathrm{and}~ \epsilon \sim \rho(\epsilon)$$ This essentially removes the non-determinism we had before and replaces it with a deterministic function $g(\phi, \mathrm{x}, \epsilon)$ that takes input from a distribution $\rho$. Practically, this means that we will sample from $\rho(\epsilon)$ and run the result through the decoder; later, during optimization, we only need to backpropagate as far as the function $g(\phi, \mathrm{x}, \epsilon)$ and no further, treating the sampled $\epsilon$ as just another input.
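A minimal PyTorch sketch of this idea. The tensors `mu` and `log_sigma` are stand-ins for real encoder outputs, and the squared loss is a stand-in for the decoder/ELBO term:

```python
import torch

# Stand-ins for the encoder outputs of q_phi(z|x); in a real VAE these
# come from a network. log_sigma keeps sigma positive after exp().
mu = torch.tensor([0.3], requires_grad=True)
log_sigma = torch.tensor([-1.0], requires_grad=True)

eps = torch.randn(1)                  # sample from rho(eps) = N(0, 1)
z = mu + log_sigma.exp() * eps        # deterministic g(phi, x, eps)

loss = (z ** 2).mean()                # stand-in for the downstream loss
loss.backward()                       # gradients flow through g into mu, log_sigma

print(mu.grad, log_sigma.grad)        # eps is treated as a plain input: no gradient
```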

However, one issue still persists: we still cannot interchange the derivative w.r.t. $\phi$ and the integral, because the expectation is still taken over $q_\phi(z\mid x)$ (the interaction between the integral and the derivative mentioned before).
To remedy this, let's expand on the idea of change of variable. Imagine a random variable $\mathrm{X}$ with PDF $p(\mathrm{x})$, i.e. $\mathrm{x} \sim p(\mathrm{x})$. Next, let's have a deterministic function of the random variable, such that: $$g(\mathrm{x}): \mathrm{x} \mapsto \mathrm{y}$$ Now imagine going over all of the possible values of $\mathrm{X}$ and transforming them with $g(x)$. The set of all possible values that results from $g(x)$ will then be another random variable, which we'll call $\mathrm{Y}$. Note that this formulation is essentially the same as the change of variable for $\mathrm{Y}$.
For example, if we have $g(x) = 2x$, then: $$ x \sim p(x) \\ g(x) = g(\cdots,1.2, \cdots,3.2,\cdots) = \cdots,2.4,\cdots, 6.4,\cdots $$ An important observation to be made here is that the probability of a value (such as $2.4$) being sampled from the random variable $\mathrm{Y}$ is equal to the probability of the original value ($1.2$) that was transformed by the function $g(x)$. This is true because the function that generates the samples of the random variable $\mathrm{Y}$ is deterministic in all its parts except for the samples that it gets from the random variable $\mathrm{X}$ with probability $p(x)$. For a more rigorous proof, check out [4].
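This is exactly what lets us estimate an expectation over $\mathrm{Y}$ using samples of $\mathrm{X}$. A small NumPy check, where $p(x)$ is chosen as an arbitrary Gaussian purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Samples of X ~ p(x); p is N(1, 1) purely for illustration.
x = rng.normal(1.0, 1.0, size=100_000)

def g(x):
    return 2 * x

# Each g(x_i) inherits the probability of the x_i that produced it, so
# E[Y] = E[g(X)] can be estimated directly from samples of X:
print(g(x).mean())  # ~2.0, matching E[g(X)] = 2 * E[X]
```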

A side note: you might think that the function needs to be injective for this to work; however, it can be shown that this property holds for non-injective functions as well, because the probabilities of all inputs that map to the same output simply add up. For example, if we have the non-injective function $g(x) = \lvert x \rvert$, then both $x = -2$ and $x = 2$ map to $y = 2$; if the base probability is $\rho(-2) = \rho(2) = 0.1$, then: $$ p_Y(2) = \rho(-2) + \rho(2) = 0.2 $$


Equipped with this knowledge, let’s tackle the loss function $(2)$ again:

$$\begin{aligned} \nabla_\phi\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\phi\mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ &= \nabla_\phi\int \left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]q_\phi(z\mid x)~dz\\ \end{aligned}$$

We know that $z = g(\phi, \mathrm{x}, \epsilon)$ and that the probability of the original values $\epsilon$ is $\rho(\epsilon)$. From the previous “proof” we know that the probability of drawing a particular $z$ must equal the probability of drawing the $\epsilon$ that produced it, i.e. $q_\phi(z\mid x)~dz = \rho(\epsilon)~d\epsilon$:

$$\begin{aligned} \nabla_\phi\mathcal{L}_{\theta,\phi}(\mathrm{x}) &= \nabla_\phi\int \left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\rho(\epsilon)~d\epsilon\\ &= \nabla_\phi\mathbb{E}_{\rho(\epsilon)}\left[\log p_\theta(x, z) - \log q_\phi(z\mid x)\right]\\ \end{aligned}$$ Finally, there is no interaction between the derivative and the integral, since $\rho(\epsilon)$ does not depend on $\phi$; therefore we can interchange them and get our final gradient: $$ \nabla_\phi\mathcal{L}_{\theta,\phi}(\mathrm{x}) = \mathbb{E}_{\rho(\epsilon)}\left[\nabla_\phi\left(\log p_\theta(x, z) - \log q_\phi(z\mid x)\right)\right], \quad z = g(\phi, \mathrm{x}, \epsilon) $$ Note that $z$ still depends on $\phi$ through $g$, so both terms inside the brackets contribute to the gradient via the chain rule through $g$.
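To see this estimator in action, here is a small PyTorch sketch that estimates the gradient of $\mathbb{E}_{\mathcal{N}(\mu,\sigma^2)}[z^2]$ by reparametrization. The objective $z^2$ is a toy stand-in for the bracketed ELBO terms, chosen because its exact gradients, $2\mu$ and $2\sigma$, are known in closed form and let us verify the estimate:

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.5, requires_grad=True)

# Reparametrized Monte Carlo estimate of E_{N(mu, sigma^2)}[z^2],
# whose exact gradients are d/dmu = 2*mu and d/dsigma = 2*sigma.
eps = torch.randn(100_000)
z = mu + sigma * eps
torch.mean(z ** 2).backward()

print(mu.grad, sigma.grad)  # close to 1.0 and 3.0
```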

Conclusion

There were a few moving parts here, but all of the derivations in this article sum up to this: if we have a stochastic layer in a neural network, as is the case for VAEs, then we can separate the sampling (stochastic) process from the gradient descent computation and treat the sampling result as just another input to the network when backpropagating the gradients.
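In practice, you rarely have to write the trick by hand: PyTorch's torch.distributions exposes it through rsample(). A minimal sketch, where the sum() loss is just a placeholder:

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
sigma = torch.ones(3, requires_grad=True)

q = Normal(mu, sigma)
z = q.rsample()   # reparametrized draw: z = mu + sigma * eps, differentiable
loss = z.sum()    # placeholder loss
loss.backward()

print(mu.grad)    # populated; q.sample() would instead block these gradients
```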

References

[1] Wikipedia contributors. (2023). Leibniz integral rule. Wikipedia. https://en.wikipedia.org/wiki/Leibniz_integral_rule
[2] Wikipedia contributors. (2023). Change of variables. Wikipedia. https://en.wikipedia.org/wiki/Change_of_variables
[3] Wikipedia contributors. (2023). Location–scale family. Wikipedia. https://en.wikipedia.org/wiki/Location-scale_family
[4] Soch, J. (2020, July 22). Law of the unconscious statistician. The Book of Statistical Proofs. https://statproofbook.github.io/P/mean-lotus.html
[5] Reparameterization Trick - WHY & BUILDING BLOCKS EXPLAINED!