05 Feb 2018

Recently I stumbled across Decoupled Neural Interfaces using Synthetic Gradients, a paper that aims to remove one of the biggest slowdowns in Deep Learning: the propagation step. (In terms of training, it’s probably the only real step; things like moving data between GPUs aren’t often discussed from an algorithmic point of view.) If we can sidestep the need to complete a full forward and backward pass on the data before updating, we can dramatically speed up training. This is especially useful when training a single large neural network across multiple GPUs or on a distributed cluster.

I’m going to assume you’re familiar with Neural Networks, so at times I may use words like “obviously” or “we can see”. If these terms don’t apply to you, don’t be afraid to message me; I’ll try to point you to resources so that what I say becomes obvious in the future :) It’s all about learning, I promise.

# Foreword

The first time I heard about this paper I was absolutely flabbergasted. The idea that you could estimate your gradients seemed like cheating to me: if you can guess your error, why would you even bother with actually propagating the data through the network? If you feel the same way I did, hopefully this post will put your mind at ease by giving an intuitive reason why it works WITHOUT appealing to authority along the lines of ‘the paper is probably right because the authors are smart and know what they’re doing’. I’ll try to ask the right questions and poke at things so that we both come away richer from this deep dive into Synthetic Gradients.

# Brief Background:

Training a Deep Learning model consists of 2 main phases:

1) $\textbf{Forward Propagation}$

We start from the input layer and, layer by layer, do a matrix operation and apply an activation function until we reach the last layer, where we compute an error that describes how wrong the predictions made with our current parameters are.

2) $\textbf{Backwards Propagation}$

The error calculated above is propagated backwards through the network to get its gradient with respect to each parameter, which tells us how to reduce the error using gradient descent. We tweak the parameters along those gradients in the hope that the reduction generalizes, so that we also get a low error on some unseen test set.
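To make the two phases concrete, here is a minimal sketch on a toy network with one tanh hidden unit and one linear output unit. The network, the squared-error loss, the starting weights and the function names are all assumptions picked for illustration, not anything from the paper.

```python
import math

def forward(x, w1, w2):
    """Forward pass: one tanh hidden unit followed by a linear output unit."""
    h1 = math.tanh(w1 * x)   # hidden activation
    h2 = w2 * h1             # network output
    return h1, h2

def loss_fn(h2, y):
    """Squared error between the output and the target."""
    return 0.5 * (h2 - y) ** 2

def backward(x, y, w1, w2, h1, h2):
    """Backward pass: chain rule from the loss back to each weight."""
    d_h2 = h2 - y                    # dL/dh2
    d_w2 = d_h2 * h1                 # dL/dw2
    d_h1 = d_h2 * w2                 # dL/dh1
    d_w1 = d_h1 * (1 - h1 ** 2) * x  # dL/dw1, using tanh'(z) = 1 - tanh(z)^2
    return d_w1, d_w2

def train(x=1.0, y=0.5, w1=0.3, w2=-0.2, lr=0.1, steps=200):
    """Plain gradient descent: every update waits for both passes to finish."""
    losses = []
    for _ in range(steps):
        h1, h2 = forward(x, w1, w2)
        losses.append(loss_fn(h2, y))
        d_w1, d_w2 = backward(x, y, w1, w2, h1, h2)
        w1 -= lr * d_w1
        w2 -= lr * d_w2
    return losses

losses = train()
```

Notice that `w1` can't be touched until `backward` has walked all the way down from the loss. That waiting is exactly what the paper tries to get rid of.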

# 0) Abstract

By using Decoupled Neural Interfaces, they can estimate

1) gradients

2) inputs

for arbitrary layers of a neural network (or even blocks of layers) which allows them to asynchronously train the network’s layers / blocks depending on the degree of granularity you want. The strength of this comes from the fact that the estimation is local in that their estimator uses only local information.

Throughout this post I probably conflate Synthetic Gradients and Decoupled Neural Interfaces, but just remember that Synthetic Gradients are the estimated gradients, and DNIs are the structures that allow you to update layers / blocks without waiting for the next / previous layer.

## Intuition:

By training a second model to produce some estimate of the gradient, or the input, they don’t have to wait for the layer’s ‘turn’ in the passes. In the training process of this estimator, we generate some estimate and then use the TRUE gradient, or TRUE input, to reduce the estimator’s error, thus creating a mini-optimization step within our entire Neural Network.

Why does this work? I THINK the reason is the following: there are an infinite number of points on the very high-dimensional surface on (or near) which we assume our data lies. By training our estimator, we learn some boundary around the region where the points live, which fine-tunes our general ‘idea’ of that region. If our estimate is good enough, we then have a set of points from which we can sample. Because our final prediction only cares about how points are transformed from the input all the way to the output, training on these approximately transformed points should STILL give us a good idea of what the output should be.

Manifold assumption: The (high-dimensional) data lie (roughly) on a low-dimensional manifold

The Manifold assumption was the first thing I thought of, along with the idea that data in the same general region get transformed in the same general way, but I might be oversimplifying things.

# 1) Introduction

The authors’ intuition came from the idea of message passing:

Module: neural network layer

Message: Activation

In this case, we see that a Module processes the message before passing it on to another Module, and in this sense the transformation is local to the current interface (another block isn’t involved).

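To update the parameters $\theta_i$ of Module$_i$, they approximate the equation that backpropagation would give them:

$$\frac{\partial L}{\partial \theta_i} = f_{\text{Bprop}}\big((h_i, x_i, y_i, \theta_i), (h_{i+1}, x_{i+1}, y_{i+1}, \theta_{i+1}), \ldots\big)\frac{\partial h_i}{\partial \theta_i} \simeq \hat{f}_{\text{Bprop}}(h_i)\frac{\partial h_i}{\partial \theta_i}$$

where h is the activation (makes sense since we mentioned before that it’s local to the current ‘message’). x is the input, y is the ‘target’, and L is the loss but they’re not super-duper important here.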

# 2) Discussion and Conclusion

Their method trains faster (makes sense, since that was their goal). They also link to Understanding Synthetic Gradients and Decoupled Neural Interfaces, which goes deeper into the analysis and theoretical understanding of DNIs and synthetic gradients, and also shows convergence properties. After I’m done with this whole thing there’s a good chance I’ll go back and add a follow-up post discussing that paper.

# 3) Experiments

## 3.1 RNNs and Language Modelling

Unfortunately, I don’t feel confident enough with the whole subfield to delve into it, but you should read the actual paper (page 5).

## 3.2 Multi-Network systems

They demonstrate how useful this system is by experimenting on two RNNs, A and B, where B is executed at a slower rate than A AND must use communication from A to complete its task. Since the two networks are dependent, this slows down training.

However, by using Synthetic Gradients they show that Network A learns much quicker than when it is locked to Network B, and that Network B also learns faster (likely because A learns faster).

## 3.3 Feed-Forward Networks

They test the concept to the extreme in a few ways:

1) $\textbf{Asynchronous Updates}$

They update the layers of a network at random, each with some probability p$_{\text{update}}$, so that a layer is only updated after its forward pass p$_{\text{update}}$ of the time. This SHOULD break backprop, which needs every layer to take part in the same backward pass. With Synthetic Gradients and the decoupled interface, however, one layer’s random update doesn’t force the other layers to update with it, because each layer can use its synthetic gradient instead.

2) $\textbf{Complete unlock}$

They remove forward locking as well (they estimate the input to the current layer using the output from the previous layer). This is achieved in much the same manner as the backward locking is removed. Amazingly, it still reaches more or less the same performance as with just the synthetic gradient propagation process (albeit slightly slower).
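The asynchronous-update scheme from 1) can be pictured as a coin flip per layer at every training step. The sketch below only shows that scheduling part; `layers_to_update`, the 4 layers and the 0.5 probability are hypothetical choices for illustration, and the exact scheduling in the paper may differ.

```python
import random

def layers_to_update(num_layers, p_update, rng):
    """Decide, independently per layer, whether it is updated this step."""
    return [rng.random() < p_update for _ in range(num_layers)]

rng = random.Random(0)  # fixed seed so the sketch is repeatable

# 1000 training steps of a 4-layer network, each layer updated half the time.
schedule = [layers_to_update(4, 0.5, rng) for _ in range(1000)]

# Fraction of (step, layer) pairs in which an update actually happened.
update_rate = sum(sum(step) for step in schedule) / (4 * 1000)
```

With plain backprop, a `False` for one layer would stall the gradient for every layer below it. With a synthetic gradient at each interface, every `True` layer can go ahead by itself.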

# 4) Decoupled Neural Interfaces

We now approach the actual equations for the estimator, as well as a formulation for the overall model for our main neural network.

1) $\textbf{Notation}$

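i) Assume a feedforward network consisting of N layers $f_i, i \in \{1, \ldots, N\}$, such that each layer takes an input $h_{i-1}$ and produces an output $h_i = f_i(h_{i-1})$, where $h_0 = x$ is the input data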

ii) $\mathcal{F}_1^N$ is the entire forward propagation, where the subscript denotes the first layer and the superscript the last layer.

iii) $L = L_N$ is the loss imposed on the output of the network

iv) $\alpha$ the learning rate

2) $\textbf{Standard update}$

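Each layer $f_i \text{ has parameters } \theta_i$ that can be trained jointly to minimize $L(h_N)$ with the gradient-based update

$$\theta_i \leftarrow \theta_i - \alpha\, \delta_i \frac{\partial h_i}{\partial \theta_i}, \qquad \delta_i = \frac{\partial L}{\partial h_i}$$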

The reliance on $\delta_i$ means the update to layer i can only occur after layers i + 1 to N have executed a forward pass, then generated some loss L(h$_N$), and backpropagated the error.

Thus, layer i is locked to $\mathcal{F}_{i+1}^N$
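To see where the locking comes from, it may help to unroll $\delta_i$ with the chain rule. Here $\delta_i = \partial L / \partial h_i$ is the gradient of the loss with respect to the activation of layer i, and gradients are treated as row vectors (an assumption about convention on my part):

$$\delta_i = \frac{\partial L}{\partial h_i} = \frac{\partial L}{\partial h_N}\,\frac{\partial h_N}{\partial h_{N-1}} \cdots \frac{\partial h_{i+1}}{\partial h_i} = \delta_{i+1}\,\frac{\partial h_{i+1}}{\partial h_i}$$

Every factor on the right belongs to a later layer, so none of it exists until those layers have run forwards and then backwards.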

3) $\textbf{Modified update}$

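Let M$_{i+1}$ describe the estimator as described earlier. So, M$_{i+1}(h_i)$ = $\hat{\delta}_i$, which can be used right away to update layer i and all the layers before it:

$$\theta_n \leftarrow \theta_n - \alpha\, \hat{\delta}_i \frac{\partial h_i}{\partial \theta_n}, \qquad n \in \{1, \ldots, i\}$$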

NOTE: we’re propagating the error from the current estimate all the way backwards - it’s easy to slip up if you’re not focusing.

After a full fprop and bprop, we fit the synthetic gradients to the true gradients, minimizing

$$\| \hat{\delta}_i - \delta_i \|_2^2$$
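Here is a rough sketch of the modified update on a toy two-layer scalar network, to show the order in which things happen. Layer 1 updates immediately from the estimate $\hat{\delta}_1$, and the estimator is only fitted to the true $\delta_1$ once layer 2 has produced it. The linear estimator `a*h1 + b`, the toy target, the learning rates and the function name are all assumptions of mine, not the paper's setup.

```python
import math
import random

def train_with_synthetic_gradient(steps=3000, lr=0.1, lr_m=0.5, seed=0):
    """Train h1 = tanh(w1*x), h2 = w2*h1 with layer 1 decoupled from layer 2."""
    rng = random.Random(seed)
    w1, w2 = 0.5, 0.1   # parameters of layer 1 and layer 2
    a, b = 0.0, 0.0     # parameters of the estimator M(h1) = a*h1 + b
    losses = []
    for _ in range(steps):
        x = rng.uniform(-1.0, 1.0)
        y = 0.8 * math.tanh(1.5 * x)   # toy target (assumption)

        # Layer 1: forward pass, then an immediate update from the estimate.
        h1 = math.tanh(w1 * x)
        delta_hat = a * h1 + b                    # synthetic gradient for h1
        w1 -= lr * delta_hat * (1 - h1 ** 2) * x  # dh1/dw1 = (1 - h1^2) * x

        # Layer 2: runs afterwards and produces the true gradient for h1.
        h2 = w2 * h1
        losses.append(0.5 * (h2 - y) ** 2)
        d_h2 = h2 - y
        delta_true = d_h2 * w2                    # true dL/dh1
        w2 -= lr * d_h2 * h1

        # Estimator: one gradient step on 0.5 * (delta_hat - delta_true)^2.
        err = delta_hat - delta_true
        a -= lr_m * err * h1
        b -= lr_m * err
    return losses

losses = train_with_synthetic_gradient()
early = sum(losses[:100]) / 100   # average loss over the first 100 steps
late = sum(losses[-100:]) / 100   # average loss over the last 100 steps
```

In this single-threaded sketch nothing actually runs in parallel. The point is only that the `w1` update never reads anything computed by layer 2 in the same step.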

4) $\textbf{Contextual Synthetic Gradients}$

If we have some supervision (a prior, or context, c), we can modify our estimator to take in and use this target:

M$_{i+1}(h_i, c)$ = $\hat{\delta}_i$
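One way to picture the contextual estimator is as the same kind of model with the context appended to its input. The sketch below assumes a scalar activation, a class label as the context c, and a linear model. `one_hot`, `contextual_synthetic_gradient` and the weight values are all made up for illustration and aren't taken from the paper.

```python
def one_hot(label, num_classes):
    """Encode a class label as a list with a single 1.0 entry."""
    return [1.0 if k == label else 0.0 for k in range(num_classes)]

def contextual_synthetic_gradient(h, label, weights, num_classes):
    """Linear estimate of delta from the activation h and the context c."""
    features = [h] + one_hot(label, num_classes) + [1.0]  # [h, c, bias]
    return sum(w * f for w, f in zip(weights, features))

# One weight for h, three for the one-hot context, one for the bias.
weights = [0.5, -1.0, 2.0, 0.0, 0.1]

estimate = contextual_synthetic_gradient(0.2, 1, weights, num_classes=3)
other = contextual_synthetic_gradient(0.2, 0, weights, num_classes=3)
```

The same activation `h` gives a different estimate depending on the label, which is the extra information the plain M$_{i+1}(h_i)$ doesn't have.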

NB: In the paper they also discuss synthetic gradients for Recurrent Networks, which I won’t do (they actually discuss it first and say that it’s easier, but I disagree).

# 5) Supplementary Material Discussion

## 5.1) Unified View of Synthetic Gradients

The paper discusses BP(0), BP($\lambda$), and Recurrent BP($\lambda$), but all of these require some knowledge of TD($\lambda$) and IMO aren’t essential for the objective of this post: to gain an insight into what is happening, why they do what they do, and why the results are as they are.

## 5.2) Are Synthetic Gradients Sufficient?

Without jumping into the followup paper, we’re going to talk about why the authors say that the synthetic gradients are sufficient.

The basic intuition behind it is that in SGD, we’re updating parameters according to a sample of the expected loss gradient (assume the loss gradient is drawn from some distribution over the inputs and the output labels).
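A small sketch of that sampling view: on a toy 1-D regression problem (my own choice, not the paper's), each minibatch gradient differs from the full-data gradient, but averaging over equally sized minibatches recovers it. The data, the weight `w` and the function `grad` are assumptions for illustration.

```python
def grad(w, batch):
    """Gradient w.r.t. w of the mean of 0.5 * (w*x - y)^2 over a batch."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)

# Toy dataset: y = 2x + 1 on x = 0..7.
data = [(float(x), 2.0 * x + 1.0) for x in range(8)]
w = 0.5

full_gradient = grad(w, data)

# Split the data into 4 minibatches of 2 points each.
minibatches = [data[i:i + 2] for i in range(0, len(data), 2)]
minibatch_gradients = [grad(w, mb) for mb in minibatches]
mean_of_minibatches = sum(minibatch_gradients) / len(minibatch_gradients)
```

So SGD already works with noisy estimates of the gradient it really wants. A synthetic gradient is another estimate of that same quantity, produced by a learned model instead of by a backward pass.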

The authors draw parallels between this paper and an actor-critic architecture where the Neural Network is the actor (as usual), and the Synthetic gradient is the critic.

# 6) Afterword

I had an important discussion with a good friend today from CMU about the importance of both understanding the applications of things, as well as understanding the math behind what is happening. I personally hope that in this blog post I have been able to give you a basic idea of the two of them, but there are portions where I have skimmed over the details or just omitted them altogether.

It is my hope that this post will explain things in a very intuitive manner so that when you read DeepMind’s own blogpost, or the papers themselves, you have a firm grounding in what to look for. You should read the papers, then try to understand the math and maybe implement the code if you can to test your understanding :)