Introduction

A major shortcoming of modern neural networks is their inability to perform continual learning. This is largely due to their susceptibility to catastrophic forgetting. Catastrophic forgetting refers to the phenomenon in which any significant alteration to an already trained neural network’s weights leads to a ‘catastrophic loss’ of what has been learned.

This poses a fundamental constraint on the way that neural networks can be trained and deployed: at training time, we need to expose neural networks to the entire training data distribution, and once deployed, we can’t incrementally teach them from new data. The only way to do so is to recover all of the original data, add the new data, and essentially retrain the network.

This seems like a serious problem, but the current paradigm of deep learning allows us to sidestep it by preemptively collecting all the data we might need and then running a single optimization job on the large dataset. This works extremely well in practice, which weakens the argument for sequential or incremental learning. For now, the cost of training on the entire data distribution at once isn’t all that high.

However, looking forward, it’s clear that catastrophic forgetting will eventually carry catastrophic costs. For example, let’s imagine a future where neural network agents go out into the world and continuously collect new data to learn from. As an agent consumes more and more data, even tiny, incremental learning tasks will require replaying all the training data the agent has seen in its lifetime, making learning anything new increasingly expensive.

As an aside, we can think about how the constant pressures of survival may have caused nature to evolve brains that could learn incrementally without shutting down for inordinate amounts of time. One could argue that artificial intelligence doesn’t necessarily have this constraint. However, we will likely have two very different types of artificial intelligence agents in the near future: ones that are trained, say, once a year on a massive dataset, whose learned network stays essentially static and is employed as an expert system for specific tasks, and other, more general agents that go out into the world, exploring it and learning about it in an online manner. For the latter, it will be crucial that we solve the problem of continual learning.

A Definition of Continual Learning

Let’s begin with a definition of continual learning. The authors define continual learning with a set of 5 desiderata:

1. No Catastrophic Forgetting: The agent should perform reasonably well on previously learnt tasks.
2. Positive Forward Transfer: The agent should learn new tasks while taking advantage of knowledge extracted from previous tasks.
3. Scalability: The agent should be trainable on a large number of tasks.
4. Positive Backward Transfer: The agent should be able to gain improved performance on previous tasks after learning a new task which is similar or relevant.
5. Learn without requiring task labels: The agent should be applicable in the absence of clear task boundaries.

As we will see later, the proposed model in the paper meets, to some extent, all of these desiderata.

Previous Approaches

Two previous works are particularly relevant to this paper. The proposed model combines them into a more scalable and robust design.

EWC (Elastic Weight Consolidation) (Kirkpatrick et al., 2017)

Elastic Weight Consolidation was proposed to address catastrophic forgetting. It can be seen as a regularization method which uses a Bayesian perspective on the network parameters to define a regularization term that prevents the parameters from diverging too greatly from the optimal values for previous tasks. EWC was shown to be quite effective at preventing catastrophic forgetting, but because a regularization term is added for each new task, the network’s ability to learn new tasks can be diminished.

Figure 1. EWC

The operation of EWC is as follows:

1. Optimize the parameters on the first task.
2. Use the optimized parameters to define a prior on future parameters.
3. Move to a new task.
4. Optimize the parameters on a modified objective whose regularization term tries to keep the parameters close to an ideal setting for previous tasks.

Progressive Networks (Rusu et al., 2016)

Progressive networks take a different approach. Rather than regularizing and ‘protecting’ the weights learned on previous tasks, they overcome catastrophic forgetting by growing the network capacity and creating new weights for new tasks. It was shown that this method is very effective at avoiding forgetting, but it is extremely costly because the network size scales quadratically in the number of tasks.

Figure 2. Progressive Networks

The operation of Progressive Networks is as follows:

1. Start with a single column (network).
2. Train it on the first task.
3. Move to a second task.
4. The first column is frozen and a new column with new parameters is instantiated.
5. Layer ${h^{(2)}_i}$ receives input from both ${h^{(2)}_{i-1}}$ and ${h^{(1)}_{i-1}}$ via lateral connections.
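The quadratic growth mentioned earlier can be made concrete by counting weight blocks. The following is a minimal sketch, assuming one column per task, the same number of layers in every column, and one lateral connection into each layer above the first from every earlier column. The helper `count_weight_blocks` is hypothetical and not taken from the paper.

```python
def count_weight_blocks(num_tasks: int, num_layers: int) -> int:
    """Count within-column and lateral weight blocks in a progressive network.

    Assumptions (for illustration only): one column per task, every column
    has `num_layers` layers, and each layer above the first receives a
    lateral connection from the previous layer of every earlier column.
    """
    within = num_tasks * num_layers
    # Column k (0-indexed) has k earlier columns feeding its layers 2..num_layers.
    lateral = sum(k * (num_layers - 1) for k in range(num_tasks))
    return within + lateral


for tasks in (1, 2, 4, 8):
    print(tasks, count_weight_blocks(tasks, num_layers=3))
```

Under these assumptions the lateral term is `(num_layers - 1) * num_tasks * (num_tasks - 1) / 2`, which is the source of the quadratic scaling.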

Elastic Weight Consolidation in depth

Because EWC is central to the paper (the authors propose and use an improved version of EWC) and the mathematical justification for it is somewhat involved, this section goes over the derivation of EWC in detail. Note that most of the explanation follows Ferenc Huszár’s blog post, which I highly recommend.

The key idea of EWC is to maintain the Bayesian log posterior $\log p(\theta|D_A,D_B,\ldots,D_{k-1})$. For task $k$, we can use this posterior from previous tasks as a prior for the current task. The prior acts as a regularizer, keeping $\theta$ in a region where catastrophic forgetting of tasks $D_A,D_B,\ldots,D_{k-1}$ is minimized. This idea is illustrated in Figure 1.

Let’s walk through the example where we have two tasks, A and B. The posterior after having seen $D_A$ and $D_B$ can be decomposed as follows:

$\log p(\theta|D_A,D_B) = \log p(D_B|\theta) + \log p(\theta|D_A) - \log p(D_B|D_A)$
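This decomposition is Bayes’ rule applied with $p(\theta|D_A)$ playing the role of the prior. As a sketch of the intermediate step, assuming (as the derivation implicitly does) that $D_B$ is conditionally independent of $D_A$ given $\theta$:

$p(\theta|D_A,D_B) = \frac{p(D_B|\theta,D_A)\,p(\theta|D_A)}{p(D_B|D_A)} = \frac{p(D_B|\theta)\,p(\theta|D_A)}{p(D_B|D_A)}$

Taking the logarithm of both sides gives the three terms above.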

The first term on the right, $\log p(D_{B}|\theta)$, is the log likelihood, i.e. the objective function for task B, and is tractable. The second term on the right, $\log p(\theta | D_{A})$, is the log posterior of $\theta$ given data from task A, and is intractable. The third term is constant w.r.t. $\theta$ and doesn’t matter. This leaves us with the task of approximating $\log p(\theta | D_{A})$.

It turns out that we can do this by the method of Laplace approximation. Let’s go over the steps briefly. Recall that we can find the mode $\theta^*_A$ of the posterior via optimization:

$\theta^*_A = \textrm{argmin}_{\theta} \{-\log p(\theta|D_A)\}$

We know that the gradient of $-\log p(\theta | D_{A})$ w.r.t. $\theta$ is 0 at ${{\theta}^*}_{A}$, so the first-order term of a Taylor expansion vanishes there. A local approximation of the posterior then becomes possible, using a 2nd order Taylor expansion around ${{\theta}^*}_{A}$:

$-\log p(\theta|D_A) \approx \frac{1}{2}(\theta-\theta^*_A)^\top H(\theta^*_A)(\theta-\theta^*_A)+\textrm{constant}$

Here, $H({{\theta}^*}_{A})$ is the Hessian of the negative log posterior $-\log p(\theta | D_{A})$, evaluated at ${{\theta}^*}_{A}$. The Hessian itself is approximated using the Fisher information matrix, using the property that the Fisher is equivalent to the second derivative of the loss near a minimum (Pascanu and Bengio, 2013):

$H(\theta^*_A) \approx N_A\cdot F(\theta^*_A)+H_\textrm{prior}(\theta^*_A)$

$N_A$ is the number of i.i.d. observations in $D_A$ and $F({{\theta}^*}_{A})$ is the Fisher information matrix on task A. The second term on the right, the Hessian of the negative log prior, can be ignored because the paper doesn’t use a prior on $\theta$. The Fisher matrix can be computed from first-order derivatives alone, which is much less expensive than computing second derivatives, and is thus easy to calculate even for large models.

Finally, we have

$\log p(\theta|D_A,D_B) \approx \log p(D_B|\theta) - \frac{1}{2}\sum\limits_{i}(N_A F_{A,i}+\lambda_\textrm{prior})(\theta_i-\theta^*_{A,i})^2+\textrm{constant}$

The approximation to the log posterior used here is known as Laplace’s method, and it is equivalent to a Gaussian approximation to the posterior around its mode.
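To make Laplace’s method concrete, here is a minimal one-dimensional sketch. It assumes a toy ‘posterior’ proportional to a Gamma density (an illustrative choice, not from the paper), finds the mode numerically, and uses the second derivative of the negative log density at the mode as the precision of the fitted Gaussian. The function names are hypothetical.

```python
import math


def neg_log_post(theta: float, a: float = 5.0, b: float = 2.0) -> float:
    """Unnormalised negative log density of a Gamma(a, b) toy posterior."""
    return -((a - 1.0) * math.log(theta) - b * theta)


def laplace_1d(f, lo: float, hi: float, eps: float = 1e-4):
    """Return (mode, variance) of the Gaussian that Laplace's method fits to exp(-f)."""
    # Golden-section search for the minimiser of f on [lo, hi] (f assumed unimodal).
    phi = (math.sqrt(5.0) - 1.0) / 2.0
    for _ in range(200):
        c = hi - phi * (hi - lo)
        d = lo + phi * (hi - lo)
        if f(c) < f(d):
            hi = d
        else:
            lo = c
    mode = 0.5 * (lo + hi)
    # Central finite difference for the second derivative (the 1-D Hessian).
    hess = (f(mode + eps) - 2.0 * f(mode) + f(mode - eps)) / eps ** 2
    return mode, 1.0 / hess


mode, var = laplace_1d(neg_log_post, 0.1, 10.0)
# Analytically, the mode is (a - 1) / b = 2 and the variance is (a - 1) / b**2 = 1.
print(mode, var)
```

In EWC the same idea is applied in a very high-dimensional parameter space, where the Hessian is replaced by the Fisher information rather than computed by finite differences.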

‘Diagonalized Laplace Approximation’

• Because the parameter space is high dimensional, EWC makes a further diagonal approximation of F, treating its off-diagonal entries as 0.
• These diagonal Fisher information values can be computed easily via back-propagation with minimal change to the stochastic gradient descent algorithm used to find the optimum ${{\theta}^*}_{A}$
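As an illustration of how the diagonal Fisher needs only first-order gradients, here is a minimal sketch for logistic regression (a model chosen for brevity, not the network used in the paper). Each diagonal entry is the expected squared gradient of the log likelihood, with the expectation over labels taken under the model itself. All names and values are illustrative assumptions.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def grad_log_lik(theta, x, y):
    """Gradient of log p(y | x, theta) for logistic regression (y in {0, 1})."""
    p = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
    return [(y - p) * xi for xi in x]


def diagonal_fisher(theta, inputs):
    """Diagonal Fisher: mean over inputs of E_{y ~ model}[squared gradient]."""
    fisher = [0.0] * len(theta)
    for x in inputs:
        p1 = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
        # Only two labels exist, so the expectation over y is computed exactly.
        for y, prob in ((1, p1), (0, 1.0 - p1)):
            g = grad_log_lik(theta, x, y)
            for i, gi in enumerate(g):
                fisher[i] += prob * gi * gi
    return [f / len(inputs) for f in fisher]


theta_star = [0.5, -1.0]  # stands in for the optimum found on task A
inputs = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
print(diagonal_fisher(theta_star, inputs))
```

For a deep network the per-example gradients would come from back-propagation, and the expectation over labels is typically estimated by sampling rather than enumerated.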

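Finally, we get the EWC optimization objective, which lets us find the mode $\theta^*_B$ via gradient descent. Here the data count $N_A$ is replaced by $\lambda_A$, which is treated as a hyperparameter that sets how strongly task A is protected.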

$\theta^*_B = \textrm{argmin}_{\theta} \Bigg\{ -\log p(D_B|\theta)+\frac{1}{2}\sum\limits_{i}(\lambda_A F_{A,i}+\lambda_\textrm{prior})(\theta_i-\theta^*_{A,i})^2\Bigg\}$
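The objective above can be sketched in a few lines. This is a toy example under stated assumptions: the task-B loss is a simple quadratic, the diagonal Fisher values are made up, and `ewc_penalty` and `ewc_step` are hypothetical helper names rather than anything from the paper.

```python
def ewc_penalty(theta, theta_a, fisher_a, lam_a, lam_prior=0.0):
    """0.5 * sum_i (lam_a * F_{A,i} + lam_prior) * (theta_i - theta*_{A,i})^2."""
    return 0.5 * sum((lam_a * f + lam_prior) * (t - ta) ** 2
                     for t, ta, f in zip(theta, theta_a, fisher_a))


def ewc_step(theta, grad_task_b, theta_a, fisher_a, lam_a, lr, lam_prior=0.0):
    """One gradient-descent step on the task-B loss plus the EWC penalty."""
    return [t - lr * (g + (lam_a * f + lam_prior) * (t - ta))
            for t, g, ta, f in zip(theta, grad_task_b, theta_a, fisher_a)]


# Toy setup (all values are illustrative assumptions).
theta_a = [1.0, 1.0]    # optimum found on task A
fisher_a = [10.0, 0.0]  # the first parameter matters for task A, the second does not
target_b = [0.0, 0.0]   # task-B loss is 0.5 * ||theta - target_b||^2

theta = list(theta_a)
for _ in range(500):
    grad_b = [t - tb for t, tb in zip(theta, target_b)]
    theta = ewc_step(theta, grad_b, theta_a, fisher_a, lam_a=1.0, lr=0.05)
print(theta)
```

In this toy problem the parameter with high Fisher information stays close to its task-A value, while the parameter with zero Fisher information is free to move to the task-B optimum.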

Progress & Compress Framework

The paper proposes to tie these models together. It combines elements of Progressive Networks, EWC, and knowledge distillation to create a model that learns a new task (Progressive Networks) and then compresses the learned knowledge (knowledge distillation) while preventing catastrophic forgetting (EWC). The authors name the model Progress & Compress (P&C).

The Progress & Compress framework is quite simple: two neural networks alternate between learning a new task and compressing the learned knowledge into a knowledge base. The two components are called the Active Column (AC) and the Knowledge Base (KB), respectively.

The training happens in two phases:

• Progress phase

Let’s walk through a Progress phase. First, a new task/problem is presented. We freeze the parameters of the knowledge base and optimize only the parameters in the active column. As in Progressive Networks, layerwise connections between the knowledge base and the active column are added to enable the reuse of features encoded in the knowledge base, which allows positive transfer from previously learnt tasks. Apart from this, the active column learns the new task freely.

Here, the lateral adaptors are implemented as MLPs, and the $i^{th}$ layer of the active column is defined as follows:

$h_i = \sigma(W_ih_{i-1}+\alpha_i\odot U_i\sigma(V_i h^{KB}_{i-1}+c_i)+b_i)$

$b_{i}$ and $c_{i}$ are biases, $\alpha_{i}$ is a trainable vector of size equal to the number of units in layer $i$, and $W_{i}$, $U_{i}$, $V_{i}$ are weight matrices. $\sigma$ is the nonlinearity and $\odot$ denotes elementwise multiplication.
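The layer equation can be written out directly. The following is a minimal sketch in plain Python, assuming a logistic sigmoid for $\sigma$ (the text does not say which nonlinearity is used) and small made-up weights. The function names are hypothetical.

```python
import math


def matvec(m, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def sigma(v):
    """Elementwise nonlinearity; a logistic sigmoid is assumed here."""
    return [1.0 / (1.0 + math.exp(-z)) for z in v]


def active_column_layer(h_prev, h_kb_prev, W, U, V, alpha, b, c):
    """h_i = sigma(W h_{i-1} + alpha * (U sigma(V h^KB_{i-1} + c)) + b)."""
    adaptor_hidden = sigma([z + ci for z, ci in zip(matvec(V, h_kb_prev), c)])
    lateral = [a * u for a, u in zip(alpha, matvec(U, adaptor_hidden))]
    own = matvec(W, h_prev)
    return sigma([o + l + bi for o, l, bi in zip(own, lateral, b)])


# Illustrative shapes: 2 units per layer, 3 hidden units in the adaptor.
W = [[0.5, -0.2], [0.1, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
U = [[0.2, 0.1, -0.3], [0.0, 0.4, 0.1]]
b, c = [0.0, 0.1], [0.0, 0.0, 0.0]
h_prev, h_kb_prev = [1.0, 2.0], [0.5, -0.5]

print(active_column_layer(h_prev, h_kb_prev, W, U, V, [0.1, 0.1], b, c))
```

Setting `alpha` to zero removes the lateral term, so the layer reduces to an ordinary layer that ignores the knowledge base.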

• Compress phase

The goal of the Compress phase is to distill the learned AC into the KB. This time, we’ll fix the parameters of the AC instead. The compression optimization objective is a distillation loss with the EWC penalty. EWC protects the knowledge base against catastrophic forgetting such that all previously learnt skills are maintained.

$\mathbb{E}\Big[ \textrm{KL}\big(\pi_k(\cdot|x) \,\|\, \pi^{\textrm{KB}}(\cdot|x)\big) \Big]+\frac{1}{2} \| \theta^{\textrm{KB}}-\theta^{\textrm{KB}}_{k-1} \|^2_{\gamma F^*_{k-1}}$

Here $\theta^{KB}$ is the set of parameters being optimized, $\pi_{k}(\cdot|x)$ is the policy of the AC after learning task $k$, and $\pi^{KB}(\cdot|x)$ is the policy of the KB.
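A minimal sketch of this objective, assuming discrete action distributions, a sample mean in place of the expectation, and a diagonal Fisher so that the weighted norm becomes a weighted sum of squares. The helper names are hypothetical and the numbers are made up.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same actions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def compress_loss(ac_policies, kb_policies, theta_kb, theta_kb_prev,
                  fisher_prev, gamma):
    """Mean KL(pi_k || pi_KB) over sampled states plus the EWC penalty
    0.5 * sum_i gamma * F_i * (theta_i - theta_prev_i)^2 on the KB parameters."""
    distill = sum(kl_divergence(p, q)
                  for p, q in zip(ac_policies, kb_policies)) / len(ac_policies)
    penalty = 0.5 * sum(gamma * f * (t - tp) ** 2
                        for t, tp, f in zip(theta_kb, theta_kb_prev, fisher_prev))
    return distill + penalty


# One sampled state with two actions; KB parameters drifted by 0.5 in one coordinate.
loss = compress_loss(
    ac_policies=[[0.5, 0.5]],
    kb_policies=[[0.25, 0.75]],
    theta_kb=[1.0, 0.5],
    theta_kb_prev=[1.0, 0.0],
    fisher_prev=[3.0, 2.0],
    gamma=1.0,
)
print(loss)
```

The first term pulls the KB policy towards the AC policy, while the second term resists changes to KB parameters that the Fisher marks as important for earlier tasks.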

• Similarity with the brain’s sleep/wake cycle

The paper also notes that a similar process happens in the human brain. Recent studies have identified sleep as a brain state optimized for memory consolidation, in contrast to the waking brain, which is optimized for encoding memories. Consolidation originates from the reactivation of recently encoded neuronal memory representations. The idea is that offline consolidation of memory during sleep represents a principle of long-term memory formation established in quite different physiological systems.

Online EWC

• Proposed in Huszár (2017)
• Let $\mathcal{T}_{1:k} = (\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_k)$ be the data of a sequence of $k$ tasks
• Posterior of $\theta$ is

$p(\theta|\mathcal{T}_{1:k}) \propto p(\theta)\prod\limits_{i=1}^{k} p(\mathcal{T}_i|\theta) \propto p(\theta|\mathcal{T}_{1:k-1}) p(\mathcal{T}_k|\theta)$

• Regular EWC approximates $p(\mathcal{T}_i|\theta)$ for each task $i$ with a Gaussian

$p(\mathcal{T}_i|\theta) \approx \mathcal{N}(\theta;\theta^*_i,F^{-1}_i)$

• Recall that the resulting EWC loss for task $i$ is

$-\log p(\mathcal{T}_i|\theta) + \frac{1}{2}\sum\limits_{j=0}^{i-1}\| \theta-\theta^*_j \|^2_{F_j}$

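• This requires a mean and a Fisher to be kept for each task
• The computational cost is therefore linear w.r.t. the number of tasks
• Proposed: approximate the whole posterior over previous tasks directly with a single Gaussian, so that only the latest mean and the sum of the Fishers need to be kept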

$-\log p(\mathcal{T}_i|\theta) + \frac{1}{2}\| \theta-\theta^*_{i-1} \|^2_{\sum_{j=0}^{i-1}F_j}$
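The difference between the two penalties can be sketched as follows, assuming diagonal Fishers stored as plain lists. `regular_ewc_penalty` and `OnlineEWC` are hypothetical names, and the sketch follows the two formulas in this section rather than any released implementation.

```python
def regular_ewc_penalty(theta, means, fishers):
    """Regular EWC: 0.5 * sum_j ||theta - theta*_j||^2_{F_j}.

    Needs one (mean, Fisher) pair per previous task.
    """
    return 0.5 * sum(f * (t - m) ** 2
                     for mean, fisher in zip(means, fishers)
                     for t, m, f in zip(theta, mean, fisher))


class OnlineEWC:
    """Online EWC: keep only the latest mean and a running sum of Fishers."""

    def __init__(self, num_params):
        self.mean = [0.0] * num_params
        self.fisher_sum = [0.0] * num_params

    def update(self, new_mean, new_fisher):
        """Call after finishing a task: overwrite the mean, accumulate the Fisher."""
        self.mean = list(new_mean)
        self.fisher_sum = [s + f for s, f in zip(self.fisher_sum, new_fisher)]

    def penalty(self, theta):
        return 0.5 * sum(f * (t - m) ** 2
                         for t, m, f in zip(theta, self.mean, self.fisher_sum))


# Two previous tasks with made-up optima and Fishers.
means = [[1.0, 2.0], [1.0, 2.0]]
fishers = [[1.0, 0.5], [2.0, 1.0]]
online = OnlineEWC(num_params=2)
for mean, fisher in zip(means, fishers):
    online.update(mean, fisher)

theta = [0.0, 0.0]
print(regular_ewc_penalty(theta, means, fishers), online.penalty(theta))
```

When the per-task optima coincide, as in this toy example, the two penalties agree. In general they differ, and the online version trades some fidelity for storage that no longer grows with the number of tasks.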

Experiments

• Omniglot handwritten characters
• 6 games in the Atari suite
• Distributed actor-critic for the RL experiments: both a policy and a value function are learned from raw pixels with a shared convolutional encoder
• Online EWC is used in all P&C experiments
• Simple baseline of ‘Finetuning’ – Standard training without protection against catastrophic forgetting

Resilience against catastrophic forgetting

• 1 alphabet at a time
• P&C + online EWC is slightly worse than EWC & regular EWC

Assessing Forward Transfer

• Learn 7 mazes, test on 1 held-out maze
• Similarity between tasks is high
• All methods show positive transfer but P&C shows most