Bayesian Learning - by example

Welcome to my blog! For my first post, I decided to write a short, minimally technical introduction to Bayesian learning and its relationship with the more traditional optimization-theoretic perspective often used in artificial intelligence and machine learning. We begin by introducing an example.

Example: binary classification using a fully connected network

First, let’s introduce notation. For simplicity, suppose the network has a single hidden layer and no bias terms, and define the following.

  • $\v{y}_{N\times 1}$: a binary vector where each element is a target data point. $N$ is the number of data points.
  • $\m{X}_{N\times p}$: a matrix where each row is an input data vector. $p$ is the dimensionality of each input.
  • $\v\beta^{(x)}_{p \times m}$: the matrix that maps the input to the hidden layer. $m$ is the number of hidden units.
  • $\v\beta^{(h)}_{m \times 1}$: the vector that maps the hidden layer to the output.
  • $\sigma$: the network’s activation function, for instance a ReLU function.
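  • $\phi$: the logistic (sigmoid) function $\phi(z) = 1/(1+e^{-z})$, applied elementwise; this is the two-class case of the softmax function.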
[Figure: network diagram with input $\m{X}$, weights $\v\beta^{(x)}$ and $\v\beta^{(h)}$, and output $\v{y}$]

The standard approach

We begin by defining an optimization problem. Let $\v\beta$ be the $k$-dimensional vector, with $k = pm + m$, obtained by stacking all entries of $\v\beta^{(x)}$ and $\v\beta^{(h)}$ together. Our network’s prediction $\v{\hat{y}} \in [0,1]^N$ is given by

$$ \hat{\v{y}} = \phi\del{\sigma\del{\m{X} \v\beta^{(x)}} \v\beta^{(h)}} . $$
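The following is a minimal pure-Python sketch of this forward pass. It assumes $\sigma$ is the ReLU and $\phi$ is the logistic function $1/(1+e^{-z})$, which is the two-class form of softmax. The helper names `relu`, `logistic`, `predict` and `stack` are hypothetical, chosen only for this example. Plain lists stand in for the matrices so the snippet has no dependencies.

```python
import math
from typing import List

Matrix = List[List[float]]  # list of rows
Vector = List[float]


def relu(z: float) -> float:
    """sigma: ReLU activation for the hidden layer."""
    return max(0.0, z)


def logistic(z: float) -> float:
    """phi: logistic function, computed without overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(X: Matrix, beta_x: Matrix, beta_h: Vector) -> Vector:
    """Return y_hat = phi(sigma(X beta_x) beta_h): one probability per row of X.

    X is N x p, beta_x is p x m, beta_h has length m; there are no biases.
    """
    m = len(beta_h)
    y_hat = []
    for row in X:
        # Hidden layer: h_j = sigma(sum_l row[l] * beta_x[l][j]) for j = 1..m
        h = [relu(sum(x_l * beta_x[l][j] for l, x_l in enumerate(row)))
             for j in range(m)]
        # Output layer: a single score, squashed into (0, 1) by phi
        y_hat.append(logistic(sum(h_j * b_j for h_j, b_j in zip(h, beta_h))))
    return y_hat


def stack(beta_x: Matrix, beta_h: Vector) -> Vector:
    """Stack every weight into one vector beta of length k = p*m + m."""
    return [b for row in beta_x for b in row] + list(beta_h)


X = [[1.0, 2.0], [-1.0, 0.5]]        # N = 2 inputs, p = 2
beta_x = [[0.5, -1.0], [0.25, 1.0]]  # p x m with m = 2
beta_h = [1.0, -0.5]                 # length m
print(predict(X, beta_x, beta_h))
print(len(stack(beta_x, beta_h)))    # prints 6, i.e. k = 2*2 + 2
```

Every entry of the output lies strictly between 0 and 1, as required of a Bernoulli parameter.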

Now, we proceed to learn the weights. Let $\v{\hat\beta}$ be the learned values for $\v\beta$, let $\vert\vert\cdot\vert\vert$ be the $L^2$ norm, fix some $\lambda \in \R^+$, and set

$$ \v{\hat\beta} = \underset{\v\beta}{\arg\min}\cbr{ \sum_{i=1}^N -y_i\ln(\hat{y}_i) - (1-y_i)\ln(1 - \hat{y}_i) + \lambda\vert\vert\v\beta\vert\vert^2} . $$

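The sum is called the cross-entropy loss [1], and $\lambda\vert\vert\v\beta\vert\vert^2$ is an $L^2$ regularization term. The objective is differentiable almost everywhere, since the ReLU has a kink at zero. So we can minimize it using gradient descent or any other method we wish. Because it is not convex in $\v\beta$, in practice such methods find a local minimum. Learning takes place by minimizing the loss, and the values we learn, here $\v{\hat\beta}$, are a single point in $\R^k$.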
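The sketch below shows gradient descent minimizing this kind of objective. It makes one simplifying assumption: the hidden layer is dropped, so $\hat{y}_i = \phi(\v{x}_i^T\v\beta)$ with $\phi$ logistic. In other words, it is regularized logistic regression. In that case the gradient has the closed form $\m{X}^T(\hat{\v{y}} - \v{y}) + 2\lambda\v\beta$. With the hidden layer, one would instead compute the gradient by backpropagation. The toy data, step size and iteration count are made up for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def logistic(z: float) -> float:
    """phi: logistic function, stable for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(X: Matrix, beta: Vector) -> Vector:
    """y_hat_i = phi(x_i . beta); no hidden layer in this simplified sketch."""
    return [logistic(sum(x * b for x, b in zip(row, beta))) for row in X]


def objective(X: Matrix, y: Vector, beta: Vector, lam: float) -> float:
    """Cross-entropy loss plus the L2 penalty lam * ||beta||^2."""
    y_hat = predict(X, beta)
    ce = sum(-yi * math.log(p) - (1 - yi) * math.log(1 - p)
             for yi, p in zip(y, y_hat))
    return ce + lam * sum(b * b for b in beta)


def gradient(X: Matrix, y: Vector, beta: Vector, lam: float) -> Vector:
    """Closed-form gradient of the objective: X^T (y_hat - y) + 2 * lam * beta."""
    residual = [p - yi for p, yi in zip(predict(X, beta), y)]
    return [sum(r * row[j] for r, row in zip(residual, X)) + 2 * lam * beta[j]
            for j in range(len(beta))]


def gradient_descent(X: Matrix, y: Vector, lam: float,
                     step: float = 0.1, iters: int = 2000) -> Vector:
    """Plain fixed-step gradient descent, starting from beta = 0."""
    beta = [0.0] * len(X[0])
    for _ in range(iters):
        g = gradient(X, y, beta, lam)
        beta = [b - step * g_j for b, g_j in zip(beta, g)]
    return beta


X = [[0.5, 1.0], [-1.5, 0.5], [2.0, -1.0], [-0.3, -0.5]]  # toy inputs, p = 2
y = [1, 0, 1, 0]                                           # toy binary targets
beta_hat = gradient_descent(X, y, lam=0.1)
print(beta_hat, objective(X, y, beta_hat, lam=0.1))
```

The result `beta_hat` is a single point in $\R^k$. Here $\lambda > 0$ makes this simplified objective strongly convex, so the iterates settle at its unique minimizer.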

Why cross-entropy rather than some other mathematical expression? In most treatments of classification, the reasons given are purely intuitive: for instance, it is often said to stabilize the optimization algorithm. More rigorous treatments [1] might introduce ideas from information theory. We will provide another explanation.

The Bayesian approach

Let us now define the exact same network, but this time from a Bayesian perspective. We begin by making probabilistic assumptions about our data. Suppose $\v{y} \in \cbr{0,1}^N$, and assume that the order in which the elements of $\v{y}$ are presented cannot affect learning; this is formally called exchangeability. Then there is one and only one distribution that $\v{y}$ can follow: the Bernoulli distribution. Its parameter is the same expression $\v{\hat{y}}$ as before. Hence, let

$$ \v{y} \given \v\beta \dist\f{Ber}\sbr{\phi\del{\sigma\del{\m{X} \v\beta^{(x)}} \v\beta^{(h)}}} . $$

This is called the likelihood: it describes the assumptions we are making about the data $\v{y}$ given the parameters $\v\beta$ – here, that the data is binary and exchangeable. Now, define the prior for $\v\beta$ as

$$ \v\beta \dist\f{N}_k\del{0, \frac{\lambda^{-1}}{2}\m{I}_k} . $$

This describes our assumptions about $\v\beta$ external to the data. Here, we have assumed that all components of $\v\beta$ are a priori independent Gaussians with mean zero and variance $\lambda^{-1}/2$. We can combine the prior and likelihood using Bayes’ Rule

$$ f(\v\beta \given \v{y}) = \frac{f(\v{y} \given \v\beta) \pi(\v\beta)}{\int_{\R^k} f(\v{y} \given \v\beta) \pi(\v\beta) \dif \v\beta} \propto f(\v{y} \given \v\beta) \pi(\v\beta) $$

to obtain the posterior $\v\beta \given \v{y}$. This is a probability distribution: it describes what we learned about $\v\beta$ from the data. Learning takes place through the use of Bayes’ Rule, and the values we learn – here, $\v\beta \given \v{y}$ – are a probability distribution on $\R^k$.
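The sketch below applies Bayes’ Rule, including the integral in the denominator, to a deliberately tiny case. It assumes a single weight $\beta$, no hidden layer, and $\hat{y}_i = \phi(x_i\beta)$ with $\phi$ logistic. Because $k = 1$, the integral over $\R^k$ can be approximated by a sum over a grid. The data, $\lambda$ and grid bounds are illustrative assumptions. This brute-force approach does not scale to realistic $k$.

```python
import math
from typing import List, Tuple


def softplus(z: float) -> float:
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def log_likelihood(beta: float, x: List[float], y: List[int]) -> float:
    """Bernoulli log-likelihood with success probability phi(x_i * beta).

    Uses log phi(z) = -softplus(-z) and log(1 - phi(z)) = -softplus(z).
    """
    return sum(-yi * softplus(-xi * beta) - (1 - yi) * softplus(xi * beta)
               for xi, yi in zip(x, y))


def log_prior(beta: float, lam: float) -> float:
    """Log density of N(0, 1/(2*lam)), i.e. -lam * beta^2 + const."""
    var = 1.0 / (2.0 * lam)
    return -beta * beta / (2.0 * var) - 0.5 * math.log(2.0 * math.pi * var)


def grid_posterior(x: List[float], y: List[int], lam: float,
                   lo: float = -10.0, hi: float = 10.0,
                   n: int = 4001) -> Tuple[List[float], List[float], float]:
    """Approximate the posterior density f(beta | y) on an evenly spaced grid."""
    width = (hi - lo) / (n - 1)
    grid = [lo + i * width for i in range(n)]
    # Numerator of Bayes' Rule, on the log scale: f(y | beta) * pi(beta)
    log_num = [log_likelihood(b, x, y) + log_prior(b, lam) for b in grid]
    shift = max(log_num)  # subtract the max before exponentiating to avoid underflow
    num = [math.exp(v - shift) for v in log_num]
    # Denominator: the integral over beta, approximated by a Riemann sum.
    # Numerator and denominator share the factor exp(-shift), which cancels.
    evidence = sum(num) * width
    density = [v / evidence for v in num]
    return grid, density, width


x = [0.5, -1.5, 2.0, -0.3]  # toy inputs
y = [1, 0, 1, 0]            # toy binary targets
grid, density, width = grid_posterior(x, y, lam=0.5)
mean = sum(b * d for b, d in zip(grid, density)) * width
sd = math.sqrt(sum((b - mean) ** 2 * d for b, d in zip(grid, density)) * width)
print(f"posterior mean {mean:.3f}, posterior sd {sd:.3f}")
```

The output is a whole density over $\beta$, summarized here by its mean and standard deviation, rather than a single point.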

Connecting the two approaches

Is there any relationship between $\v{\hat\beta}$ and $\v\beta \given \v{y}$? It turns out, yes – let’s show it. First, let’s write down the posterior

$$ f(\v\beta \given \v{y}) \propto f(\v{y} \given \v\beta) \pi(\v\beta) \propto \sbr{\prod_{i=1}^N \hat{y}_i^{y_i} (1 - \hat{y}_i)^{1 - y_i}} \exp\cbr{-\lambda\v\beta^T\v\beta} . $$

Now, let’s take logs and simplify:

$$ \ln f(\v\beta \given \v{y}) = \sum_{i=1}^N y_i \ln(\hat{y}_i) + (1-y_i)\ln(1 - \hat{y}_i) - \lambda\vert\vert\v\beta\vert\vert^2 + \f{const} . $$
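This identity can be checked numerically. For brevity, the sketch below drops the hidden layer, so $\hat{y}_i = \phi(\v{x}_i^T\v\beta)$ with $\phi$ logistic; the identity itself does not depend on how $\hat{\v{y}}$ is computed. At several random values of $\v\beta$, it adds the log-likelihood and the full log-prior density, including its normalizing constant. It then confirms that adding the regularized cross-entropy always gives the same constant, $-\tfrac{k}{2}\ln(\pi/\lambda)$. The true log posterior also subtracts the log of the denominator in Bayes’ Rule, which is another constant. The data and $\lambda$ are illustrative.

```python
import math
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def logistic(z: float) -> float:
    """phi: logistic function, stable for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(X: Matrix, beta: Vector) -> Vector:
    return [logistic(sum(x * b for x, b in zip(row, beta))) for row in X]


def regularized_cross_entropy(X: Matrix, y: Vector, beta: Vector, lam: float) -> float:
    """The objective minimized in the standard approach."""
    y_hat = predict(X, beta)
    ce = sum(-yi * math.log(p) - (1 - yi) * math.log(1 - p) for yi, p in zip(y, y_hat))
    return ce + lam * sum(b * b for b in beta)


def log_lik_plus_log_prior(X: Matrix, y: Vector, beta: Vector, lam: float) -> float:
    """ln f(y | beta) + ln pi(beta), with pi = N_k(0, I / (2 lam)) including its constant."""
    y_hat = predict(X, beta)
    log_lik = sum(yi * math.log(p) + (1 - yi) * math.log(1 - p) for yi, p in zip(y, y_hat))
    k, var = len(beta), 1.0 / (2.0 * lam)
    log_prior = (-0.5 * k * math.log(2.0 * math.pi * var)
                 - sum(b * b for b in beta) / (2.0 * var))
    return log_lik + log_prior


X = [[0.5, 1.0], [-1.5, 0.5], [2.0, -1.0], [-0.3, -0.5]]  # toy inputs, k = 2
y = [1, 0, 1, 0]
lam = 0.1
rng = random.Random(0)
offsets = []
for _ in range(5):
    beta = [rng.uniform(-2.0, 2.0) for _ in range(2)]
    # If ln posterior = -(regularized cross-entropy) + const, this sum is the same for every beta
    offsets.append(log_lik_plus_log_prior(X, y, beta, lam)
                   + regularized_cross_entropy(X, y, beta, lam))
print(offsets)
print(-0.5 * 2 * math.log(math.pi / lam))  # predicted constant: -(k/2) ln(pi / lambda)
```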

Having computed the log posterior, note that taking logs (a strictly increasing transformation) and adding constants both preserve the location of optima, and consider the posterior mode:

$$
\begin{aligned}
\underset{\v\beta}{\arg\max}\cbr{f(\v\beta \given \v{y})} &= \underset{\v\beta}{\arg\max}\cbr{\ln f(\v\beta \given \v{y})} \\
&= \underset{\v\beta}{\arg\max}\cbr{ \sum_{i=1}^N y_i \ln(\hat{y}_i) + (1-y_i)\ln(1 - \hat{y}_i) - \lambda\vert\vert\v\beta\vert\vert^2 } \\
&= \underset{\v\beta}{\arg\min}\cbr{ \sum_{i=1}^N -y_i \ln(\hat{y}_i) - (1-y_i)\ln(1 - \hat{y}_i) + \lambda\vert\vert\v\beta\vert\vert^2 } \\
&= \v{\hat{\beta}} .
\end{aligned}
$$

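What have we shown? Minimizing the regularized cross-entropy loss is equivalent to finding the mode of the posterior distribution, known as the maximum a posteriori (MAP) estimate. Up to additive constants, the cross-entropy term is the negative log-likelihood, and the regularization term is the negative log-prior.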
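The correspondence can be illustrated with a one-weight toy model: $\hat{y}_i = \phi(x_i\beta)$, $\phi$ logistic, no hidden layer, and made-up data. The sketch below searches a grid for two things at several values of $\lambda$: the minimizer of the regularized cross-entropy and the maximizer of the log posterior. It also prints the prior standard deviation $\sqrt{1/(2\lambda)}$ that each penalty corresponds to. A stronger penalty is a narrower prior, which pulls the estimate toward zero.

```python
import math
from typing import Dict, List, Tuple


def softplus(z: float) -> float:
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def cross_entropy(beta: float, x: List[float], y: List[int]) -> float:
    """sum_i -y_i ln(y_hat_i) - (1 - y_i) ln(1 - y_hat_i), with y_hat_i = phi(x_i * beta)."""
    return sum(yi * softplus(-xi * beta) + (1 - yi) * softplus(xi * beta)
               for xi, yi in zip(x, y))


def penalized_loss(beta: float, x: List[float], y: List[int], lam: float) -> float:
    """Cross-entropy plus the L2 penalty lam * beta^2."""
    return cross_entropy(beta, x, y) + lam * beta * beta


def log_posterior(beta: float, x: List[float], y: List[int], lam: float) -> float:
    """Log-likelihood plus the log density of N(0, 1/(2 lam)); exact up to the log evidence."""
    var = 1.0 / (2.0 * lam)
    log_prior = -beta * beta / (2.0 * var) - 0.5 * math.log(2.0 * math.pi * var)
    return -cross_entropy(beta, x, y) + log_prior


def prior_sd(lam: float) -> float:
    """Prior standard deviation implied by the penalty lam * beta^2."""
    return math.sqrt(1.0 / (2.0 * lam))


x = [0.5, -1.5, 2.0, -0.3]  # toy inputs
y = [1, 0, 1, 0]            # toy binary targets
grid = [i / 1000 for i in range(-10000, 10001)]  # beta in [-10, 10], spacing 0.001
results: Dict[float, Tuple[float, float]] = {}
for lam in (0.1, 1.0, 10.0):
    beta_loss = min(grid, key=lambda b: penalized_loss(b, x, y, lam))
    beta_map = max(grid, key=lambda b: log_posterior(b, x, y, lam))
    results[lam] = (beta_loss, beta_map)
    print(f"lambda={lam:<5} prior sd={prior_sd(lam):.3f} "
          f"argmin loss={beta_loss:.3f} posterior mode={beta_map:.3f}")
```

The two searches agree up to the grid spacing, because the two objectives differ only by a sign and a constant.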

What it all means

Why is this useful? It gives us a probabilistic interpretation of learning, which helps us construct and understand our models. This is especially true in more complicated settings. For instance, we might ask where the form $\v{\hat{y}} = \phi\del{\sigma\del{\m{X} \v\beta^{(x)}} \v\beta^{(h)}}$ comes from. In fact, we can use ideas from Bayesian Nonparametrics to derive $\v{\hat{y}}$ by considering a likelihood on a function space under a ReLU basis expansion [2]. The network’s loss and architecture can both be explained in a Bayesian way.

There is much more: we could draw samples from the posterior distribution to quantify our remaining uncertainty about $\v\beta$ after seeing the data. Markov Chain Monte Carlo [3] methods are the most common class of methods for doing so. We can use ideas from hierarchical Bayesian models to define better regularizers than $L^2$; the Horseshoe [4] prior is a popular example. For brevity, I’ll omit further examples. The book Bayesian Data Analysis [5] is a good introduction, though it focuses largely on methods of interest to statisticians.
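As a taste of MCMC, here is a minimal random-walk Metropolis sketch. It uses a deliberately tiny model with a single weight: $\hat{y}_i = \phi(x_i\beta)$, $\phi$ logistic, prior $\f{N}(0, \lambda^{-1}/2)$, made-up data. It needs the posterior only up to a constant, because the denominator of Bayes’ Rule cancels in the acceptance ratio. The proposal scale, chain length, burn-in and seed are arbitrary choices for this example, not tuned recommendations.

```python
import math
import random
from typing import List, Tuple


def softplus(z: float) -> float:
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def log_post(beta: float, x: List[float], y: List[int], lam: float) -> float:
    """Unnormalized log posterior: Bernoulli log-likelihood minus lam * beta^2."""
    log_lik = -sum(yi * softplus(-xi * beta) + (1 - yi) * softplus(xi * beta)
                   for xi, yi in zip(x, y))
    return log_lik - lam * beta * beta


def metropolis(x: List[float], y: List[int], lam: float, n_samples: int = 20000,
               step: float = 1.0, burn_in: int = 1000,
               seed: int = 0) -> Tuple[List[float], float]:
    """Random-walk Metropolis; returns post-burn-in draws and the acceptance rate."""
    rng = random.Random(seed)
    beta = 0.0
    current = log_post(beta, x, y, lam)
    samples: List[float] = []
    accepted = 0
    for t in range(n_samples + burn_in):
        proposal = beta + rng.gauss(0.0, step)  # symmetric Gaussian proposal
        candidate = log_post(proposal, x, y, lam)
        # Accept with probability min(1, f(proposal | y) / f(beta | y)).
        if rng.random() < math.exp(min(0.0, candidate - current)):
            beta, current = proposal, candidate
            accepted += 1
        if t >= burn_in:
            samples.append(beta)
    return samples, accepted / (n_samples + burn_in)


x = [0.5, -1.5, 2.0, -0.3]  # toy inputs
y = [1, 0, 1, 0]            # toy binary targets
samples, acceptance_rate = metropolis(x, y, lam=0.5)
mean = sum(samples) / len(samples)
sd = math.sqrt(sum((s - mean) ** 2 for s in samples) / (len(samples) - 1))
print(f"acceptance rate {acceptance_rate:.2f}, posterior mean {mean:.3f}, sd {sd:.3f}")
```

Unlike the single point $\v{\hat\beta}$, the output is a collection of draws whose spread reflects posterior uncertainty. For a long enough chain, their mean and standard deviation should be close to those of a grid approximation of the same posterior.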

How general is this perspective? Very: an abstract result called Cox’s Theorem states, in modern terms, that every true-false logic under uncertainty is isomorphic to conditional probability. This means that all learning formalizable in the above sense is Bayesian. So, if you can’t represent a given method in a Bayesian way, I would be rather worried. For a formal statement and details, see my preprint [6] on the subject.

At the end of the day, having many different mathematical perspectives enables us to better understand how learning works, because things that are not obvious from one perspective might be easy to see from another. The optimization-theoretic approach we began with did not give a clear reason why we should use cross-entropy loss. From a Bayesian point of view, it follows directly from the binary nature of the data. Sometimes the Bayesian approach has little to say about a particular problem; other times it has a lot. It is useful to know how to use it when the need arises, and I hope this short example has given at least one reason to read about Bayesian statistics in more detail.

References

  1. See Chapter 5 of Deep Learning [7].

  2. See Chapter 20 of Bayesian Data Analysis [5].

  3. See Chapter 11 of Bayesian Data Analysis [5], but note that MCMC methods are far more general than presented there. An article [8] by P. Diaconis gives a rather different overview.

  4. C. M. Carvalho, N. G. Polson, and J. G. Scott. The Horseshoe estimator for sparse signals. Biometrika, 97(2):465–480, 2010.

  5. A. Gelman, J. B. Carlin, H. S. Stern, D. B. Dunson, A. Vehtari, and D. B. Rubin. Bayesian Data Analysis. 2013.

  6. A. Terenin and D. Draper. Cox’s Theorem and the Jaynesian Interpretation of Probability. arXiv:1507.06597, 2015. 

  7. I. Goodfellow, Y. Bengio, and A. Courville. Deep Learning. 2016.

  8. P. Diaconis. The Markov Chain Monte Carlo revolution. Bulletin of the American Mathematical Society, 46(2):179–205, 2009.