Deep Learning with function spaces

Deep learning is perhaps the single most important breakthrough in statistics, machine learning, and artificial intelligence that has been popularized in recent years. It has allowed us to classify images - for decades a challenging problem - with accuracy that is now usually better than human. It has solved Computer Go, which for decades was the classical example of a board game that was exceedingly difficult for computers to play. But what exactly is deep learning?

Many popular explanations involve analogies with the human brain, where deep learning models are interpreted as complex networks of neurons interacting with one another. These perspectives are useful, but they’re not math: the fact that deep learning models mimic the brain doesn’t mean they provably work. This post will highlight some ideas that may be helpful in moving toward an understanding of why deep learning works, presented at an intuitive level. The focus will be on high-level concepts, omitting algebraic details such as the precise form of tensor products.

The Function Space Perspective

The key idea of this post is that to understand why deep learning works, we should not work with the network directly. Instead, we will define a model for learning on a space of functions, truncate that model, and obtain deep learning.

Consider the model

[ \hat{\v{y}} = f(\m{X}) ]

where the goal is to learn the function $f$ that maps data $\m{X}$ to the predicted value $\hat{\v{y}}$. But wait, how do we go about learning a function? Let’s first consider a single-variable function $f(x): \R \goesto \R$ and recall that any sufficiently well-behaved function may be written as an infinite sum with respect to a location-scale basis, i.e. we have for an appropriately defined function $\sigma$ that

[ f(x) = \sum_{k=1}^\infty a_k \, \sigma(b_k x + c_k) + d_k . ]

What’s happening here? We’re taking the function $\sigma$, shifting it left-right using $c_k$, stretching it by a combination of $a_k$ and $b_k$, and shifting it up-down by $d_k$. As long as $\sigma$ is sufficiently rich to form a basis on $\R$, if we add up infinitely many of them, we can approximate $f$ to any precision we want. To make learning possible, let’s truncate the sum, so that we sum $K$ elements instead of $\infty$, and get

[ f(x) = \sum_{k=1}^K a_k \, \sigma(b_k x + c_k) + d_k . ]

We now have a finite set of parameters, so given a data set $(\m{X},\v{y})$, we can define a probability distribution for $\v{y}$ under the predicted values $\hat{\v{y}}$, and learn the coefficients using Bayes’ Rule.

But wait: the expressions we get by following this procedure, extended to matrices and vectors, are exactly those given by a 1-layer fully connected network. This is what a fully connected network does, and this is why it works: we are expanding an arbitrary function with respect to a basis, and learning the coefficients of the expansion using Bayes’ Rule [1]. That’s it!
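To make the truncated sum concrete, here is a minimal sketch in plain Python that evaluates it for a scalar input with $\sigma$ taken to be ReLU. The function names `relu` and `truncated_expansion` are made up for illustration, the coefficients are hand-picked rather than learned, and the offsets $d_k$ are assumed to be folded into a single constant `d`.

```python
def relu(z):
    """ReLU basis function: sigma(z) = max(0, z)."""
    return max(0.0, z)


def truncated_expansion(x, a, b, c, d):
    """Evaluate sum_{k=1}^K a_k * relu(b_k * x + c_k) + d at a scalar x.

    Assumption: the offsets d_k are folded into one constant d.
    """
    return sum(a_k * relu(b_k * x + c_k) for a_k, b_k, c_k in zip(a, b, c)) + d


# Hand-picked K = 2 coefficients: relu(x) + relu(-x) equals |x|.
a, b, c, d = [1.0, 1.0], [1.0, -1.0], [0.0, 0.0], 0.0

for x in [-3.0, -0.5, 0.0, 2.0]:
    # The second and third printed columns agree for every x.
    print(x, truncated_expansion(x, a, b, c, d), abs(x))
```

With only two terms the expansion already reproduces $|x|$ exactly. For less convenient functions we would need a larger $K$ and coefficients learned from data.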

Going Deep

With the above perspective in mind, let’s consider deep learning. We’re going to apply another trick: rather than learning $f$ directly, let’s instead define functions $f^{(1)},f^{(2)},f^{(3)}$ such that

[ \hat{\v{y}} = f(\m{X}) = f^{(1)}\cbr{f^{(2)}\sbr{f^{(3)}\del{\m{X}}}} ]

It’s not obvious why we should do this, but let’s go with it for now. Then, let $\sigma$ be the ReLU function, and expand $f^{(3)}$ with respect to that basis, just as we did above, but with matrix-vector notation, to get

[ \hat{\v{y}} = f^{(1)}\cbr{f^{(2)}\sbr{ \v{a}^{(3)} \sigma\del{\m{X}\v{b}^{(3)} + \v{c}^{(3)}} + \v{d}^{(3)} }} . ]

Now, let’s expand $f^{(2)}$, yielding

[ \hat{\v{y}} = f^{(1)}\cbr{\v{a}^{(2)}\sigma\sbr{\del{\v{a}^{(3)} \sigma\del{\m{X}\v{b}^{(3)} + \v{c}^{(3)}} + \v{d}^{(3)}}\v{b}^{(2)} + \v{c}^{(2)}} + \v{d}^{(2)}} . ]

Notice that we can set $\v{b}^{(2)} = \v{1}$ and $\v{c}^{(2)} = \v{0}$ with no loss of generality to slightly simplify our expression. Upon expanding $f^{(1)}$, we are left with

[ \hat{\v{y}} = \v{a}^{(1)}\sigma\cbr{\v{a}^{(2)}\sigma\sbr{\v{a}^{(3)} \sigma\del{\m{X}\v{b}^{(3)} + \v{c}^{(3)}} + \v{d}^{(3)}} + \v{d}^{(2)}} + \v{d}^{(1)} ]

which is exactly the expression for a 3-layer fully connected network.
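As a rough sketch of what the final expression computes, the plain-Python function below evaluates it for a single input row. The helper names (`affine`, `relu_vec`, `three_layer`) and all numeric values are hypothetical, and for simplicity every coefficient matrix multiplies from the right, which is an assumption about the shapes that the post leaves unspecified.

```python
def relu_vec(v):
    """Apply ReLU to every entry of a list."""
    return [max(0.0, t) for t in v]


def affine(v, weights, offset):
    """Return v @ weights + offset, with weights given as a list of rows."""
    return [
        sum(v_i * row[j] for v_i, row in zip(v, weights)) + offset[j]
        for j in range(len(offset))
    ]


def three_layer(x, b3, c3, a3, d3, a2, d2, a1, d1):
    """Evaluate the fully expanded 3-layer expression for one input row x."""
    h3 = relu_vec(affine(x, b3, c3))   # sigma(X b3 + c3)
    h2 = relu_vec(affine(h3, a3, d3))  # sigma[a3 sigma(...) + d3]
    h1 = relu_vec(affine(h2, a2, d2))  # sigma{a2 sigma[...] + d2}
    return affine(h1, a1, d1)          # a1 sigma{...} + d1


# Arbitrary illustrative coefficients: 2 inputs -> 2 -> 2 -> 1 -> 1 output.
x = [1.0, -2.0]
b3, c3 = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
a3, d3 = [[1.0, -1.0], [1.0, 1.0]], [0.5, 0.0]
a2, d2 = [[2.0], [1.0]], [-1.0]
a1, d1 = [[3.0]], [1.0]

print(three_layer(x, b3, c3, a3, d3, a2, d2, a1, d1))
```

Each line of `three_layer` corresponds to one nested bracket in the formula, read from the inside out.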

So, what is deep learning? Deep learning is a model that learns a function $f$ by splitting it up into a sequence of functions $f^{(1)},f^{(2)},f^{(3)},\dots$, performing a ReLU basis expansion on each one, truncating it, and learning the remaining coefficients using Bayes’ Rule.

Example: why Residual Networks work

This perspective can be used to understand recently popularized techniques in deep learning. For illustrative purposes, let’s consider a 3-layer residual network. Suppose $\m{X}$ is of the same dimensionality as the network. A residual network is a model of the form

[ \begin{aligned} \hat{\v{y}} = f(\m{X}) = &f^{(1)}\cbr{f^{(2)}\sbr{f^{(3)}\del{\m{X}} + \m{X}} + \sbr{f^{(3)}\del{\m{X}} + \m{X}}} \nonumber \\ &+ \cbr{f^{(2)}\sbr{f^{(3)}\del{\m{X}} + \m{X}} + \sbr{f^{(3)}\del{\m{X}} + \m{X}}} . \end{aligned} ]

So, why do residual networks perform better? Consider the above from a Bayesian learning point of view: we start with a prior distribution - determined uniquely by the regularization term - and end with a posterior distribution that describes what we learned. Suppose that nothing is learned in the 3rd layer. Then the posterior distribution must be the same as the prior. With $L^2$ regularization, this means that the posterior mode of the coefficients of the basis expansion of $f^{(3)}$ will be zero. Hence,

[ f^{(3)}(x) = \sum_{k=1}^K 0 \, \sigma(0 \times x + 0) + 0 = 0 ]

and the model collapses to

[ \hat{\v{y}} = f(\m{X}) = f^{(1)}\cbr{f^{(2)}\sbr{\m{X}} + \m{X}} + \cbr{f^{(2)}\sbr{\m{X}} + \m{X}} . ]

Contrast this with a non-residual network, which collapses to

[ \hat{\v{y}} = f(\m{X}) = f^{(1)}\cbr{f^{(2)}\sbr{\v{0}}} = \text{constant} . ]

In reality, of course, the network learns something in deeper layers, so behavior isn’t quite this bad. But, if we suppose that deeper layers learn less and less given the same data, the model must eventually stop working if we keep adding layers. Thus, standard networks don’t work if we make them too deep. Residual networks fix the problem.
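The collapse argument can be checked numerically. The sketch below uses scalar inputs and the truncated ReLU expansion for each $f^{(i)}$, sets every coefficient of $f^{(3)}$ to zero, and compares a plain stack with the residual form given earlier. The names `expansion`, `plain`, and `residual` and the coefficient values for $f^{(1)}$ and $f^{(2)}$ are arbitrary choices for illustration, not anything taken from the residual network papers.

```python
def relu(z):
    return max(0.0, z)


def expansion(x, params):
    """Truncated ReLU expansion sum_k a_k * relu(b_k * x + c_k) + d."""
    a, b, c, d = params
    return sum(a_k * relu(b_k * x + c_k) for a_k, b_k, c_k in zip(a, b, c)) + d


def plain(x, p1, p2, p3):
    """Non-residual composition f1(f2(f3(x)))."""
    return expansion(expansion(expansion(x, p3), p2), p1)


def residual(x, p1, p2, p3):
    """Residual composition: each layer adds its input back to its output."""
    h3 = expansion(x, p3) + x
    h2 = expansion(h3, p2) + h3
    return expansion(h2, p1) + h2


# f3 has learned nothing: all of its coefficients sit at zero.
p3 = ([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 0.0)
# Arbitrary illustrative coefficients for the other two layers.
p2 = ([1.0, -0.5], [1.0, 2.0], [0.0, -1.0], 0.25)
p1 = ([0.5], [1.0], [0.0], 0.0)

for x in [-1.0, 2.0]:
    # The plain output ignores x; the residual output still depends on it.
    print(x, plain(x, p1, p2, p3), residual(x, p1, p2, p3))
```

In the plain stack the input never makes it past the dead layer, so the printed value is the same for every `x`. In the residual version the skip connection carries `x` around the dead layer, and the later layers still have something to work with.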

What have we gained from this perspective?

Thinking about function spaces can make deep learning substantially more understandable. Instead of thinking about networks, which are complicated, we can think about functions, which are in my view simpler.

The ideas above can for instance be used to understand what convolutional networks do: they make assumptions on how each $f^{(i)}$ behaves over space. Similarly, we can see why ReLU [2] units might perform slightly better than sigmoid units: because they are unbounded, fewer of them may be required to approximate a given function well.

Part of what makes functions simpler is that it is easy to visualize what scaling and shifting do to them. For example, it is easy to see that switching from ReLU to Leaky ReLU [3] units is the same as increasing the bias term in the basis expansion. It’s certainly possible that this may sometimes be helpful, but it would be a big surprise to me if doing this resulted in substantially better performance across the board.

One major question that the function space perspective raises is why learning $f^{(1)}, f^{(2)}, f^{(3)},\dots$ separately is so much easier than learning $f$ directly. I don’t know of a good answer to this question.

A key benefit of thinking with function spaces is that it gives us a principled way to derive the expressions needed to define and train networks. The residual networks presented here differ slightly from the original work in which they were presented [4]; more recent work has proposed precisely the formulas derived here [5], which were found to improve performance.

I’m not sure why deep learning is not typically presented in this way; the function space perspective is largely omitted from the classical text Deep Learning [6]. Overall, I hope that this short introduction has been useful for understanding deep learning and making the structure present in the models more transparent.

References

  1. See Chapter 20 of Bayesian Data Analysis [7].

  2. R. Hahnloser, R. Sarpeshkar, M. A. Mahowald, R. J. Douglas, H. S. Seung. Digital selection and analogue amplification coexist in a cortex-inspired silicon circuit. Nature 405(6789), 2000.

  3. A. L. Maas, A. Y. Hannun, A. Y. Ng. Rectifier Nonlinearities Improve Neural Network Acoustic Models. ICML 30(1), 2013. 

  4. K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image Recognition. CVPR 28(1), 2015. 

  5. K. He, X. Zhang, S. Ren, and J. Sun. Identity Mappings in Deep Residual Networks. ECCV 14(1), 2016. 

  6. See Chapter 6 of Deep Learning [8].

  7. A. Gelman, J. B. Carlin, H. S. Stern, D. B. Dunson, A. Vehtari, and D. B. Rubin. Bayesian Data Analysis. 2013. 

  8. I. Goodfellow, Y. Bengio, A. Courville. Deep Learning. 2016.