본 포스트에서는 2018년 9월 MICCAI에서 발표 될 논문 Keep and Learn: Continual Learning by Constraining the Latent Space for Knowledge Preservation in Neural Networks에 대해 소개하려고 합니다. 여러 사이트에 privacy-critical data가 존재하여 model이 각 사이트를 돌아다니며 incrementally 학습되어야 하는 경우, 과거 사이트에서 배운 knowledge가 서서히 잊혀지는 catastrophic forgetting 현상이 발생합니다. 이 연구에서는 knowledge preservation에 유리하도록 feature space를 모델링하고, 해당 space에 feature vector를 한정하는 방식으로 model을 학습하여 catastrophic forgetting을 극복하는 방법을 다룹니다.


Introduction

제한적인 multi-center learning 환경 (즉, 각 center의 data는 오직 해당 center에서만 활용가능한 환경)에서는 모든 center의 데이터를 학습에 한 번에 활용할 수 없기 때문에 (아래 그림의 왼쪽), model이 각 center를 순차적으로 돌아다니며 점진적으로 학습되어야 합니다 (아래 그림의 오른쪽).

environment.png

그러나, 한 stage의 data가 다음 stage에서 활용될 수 없는 환경에서의 순차적 학습 (continual learning)의 경우, 이 전 stage에서 학습된 knowledge가 서서히 잊혀지게 되는 catastrophic forgetting 현상이 발생하게 됩니다. 특히, gradient descent 알고리즘으로 optimization을 수행하는 neural network 에서는 이 문제가 더욱 심각하다고 알려져 있습니다 [1].


Baseline Models

앞서 언급한 catastrophic forgetting 문제는 딥러닝 분야의 중요한 연구주제 중 하나 입니다. 가장 naive한 방법으로는 Fine-Tuning (FT) [2]이 있는데요, 이 전 stage에서 학습이 완료된 model parameters를 initial point로 하여 현재 stage의 학습 data로 model을 tuning해 나가는 방식 입니다 (아래 그림).

FT.png

두 번째로, 딥마인드에서 발표한 Elastic Weight Consolidation (EWC) [3]이 있는데요, 이 전 stage에서 학습된 model parameters의 중요도를 정의하고 (as a Fisher information matrix), 이를 per-parameter weight decay constant로 활용하여 이 전 stage model 관점에서 중요한 model parameters가 다음 stage 학습에서 많이 변하지 못하도록 regularization 하는 기법을 제안했습니다 (아래 그림).

EWC.png

EWC의 loss function은 task solving loss  L_n (\theta) (e.g., classification의 경우 cross-entropy)와 EWC-loss  L_{EWC} (\theta) 로 구성이 됩니다 ( \theta 는 model parameters). EWC-loss는 아래의 수식에서 알 수 있듯이, 이 전 stage에서 학습된 model parameters  \theta^* 를 기준으로 현재 stage에서 학습되어야 하는 model parameters  \theta 를 per-parameter importance  F_j ( j-th element of the Fisher information matrix) 를 활용하여 constraining 하는 loss 입니다.

loss_EWC.png

세 번째로, Learning without Forgetting (LwF) [4]이 있습니다. 아래의 그림과 같이, 각 stage의 학습을 시작하기 전에 현재 stage의 모든 training examples에 대해서 이 전 stage에서 학습이 완료된 model의 feed-forward logit (LwF-logit)을 미리 계산하고, 각 example의 label과 LwF-logit을 모두 이 번 stage 학습에 활용합니다. 이 때, label은 새로운 knowledge를 학습하는 목적으로 사용하고, LwF-logit은 과거의 knowledge를 보존하는데 사용합니다. 예를들어  K-th stage 학습은  K 개의 old branches  Y_o 관점에서의 optimal feature space  Z 를 찾는 것으로 이해할 수 있습니다.

LwF.png

LwF의 loss function은 task solving loss  L_n (\theta) 와 LwF-loss  L_{LwF} (\theta) 로 구성이 됩니다. LwF-loss는 아래의 수식에서처럼, old branches  Y_o 에 대한 LwF-logit을 label로 하는 loss 입니다.

loss_LwF.png

Original LwF는 multi-task multi-center learning을 위해 제안된 방식으로, single-task multi-center learning에서는 stage가 증가함에따라 optimal feature space  Z 를 찾기 어려워지는 단점이 있습니다. 이러한 단점을 single-task multi-center learning에 맞게 보완한 model이 아래의 LwF+ 입니다. 그림과 같이 현재 stage의 old branch를 위한 LwF-logit은 이 전 stage의 old branch로부터 계산하여 사용하는 방식입니다 (각 stage별 task가 동일한 환경이기 때문에 가능한 방식).

LwF-p.png

아래의 수식과 같이,  L_{LwF+} (\theta) 는 단일 old branch  Y_o 에 대한 loss로만 구성됩니다.

loss_LwF-p.png

앞 서 언급한 EWC와 LwF (혹은 LwF+)는 neural network에서 catastrophic forgetting을 극복할 수 있는 대표적인 방법들이며, 특히 LwF는 output activation을 한정하는 방식으로, EWC는 model parameters를 한정하는 방식으로 catastrophic forgetting을 극복합니다. 즉, 두 방법은 서로 complementary하며, 따라서 아래와 같이 동시에 활용될 수 있습니다 (EWCLwF 와 EWCLwF+).

loss_EWCLwF.png

loss_EWCLwF-p.png


Proposed Method

제안하는 방식은 첫 번째 stage 학습과 나머지 stages 학습으로 나뉘어 집니다 (아래 그림의 왼쪽: 1st stage, 오른쪽: following stages).

[1st stage] Feature extractor  f 와 classifier  g 를 학습하는 동안, latent vector  z \in Z 와 its reconstruction  h(g(z)) 사이의 L2-distance를 minimize함으로써  g 의 inverse function  h 를 approximately modeling할 수 있습니다.

[From the 2nd stage] 1st stage에서 학습된  g of  \theta_n 과  h of  \theta_r 을 현재 stage의  g', h' 으로 했을 때,  g', h' 을 고정하고 old branch output  Y_o 를 LwF-logit으로 한정함으로써, shared feature space  Z 가 학습과정 중에 reconstruction될 수 있습니다.

proposed_conceptual.png

이 과정을 이론적으로 해석해보면 다음과 같습니다. Auto-encoder 구조 하에  Z 의 reconstruction error를 minimize 하는 것은 conditional entropy  H(Z|Y_n) 를 minimize 하는 것과 같습니다 [5]. 또한, task solving loss  L_n (\theta) 를 minimize 하는 것은  H(Z) 가 작아지지 않도록 해줍니다. Random variables  Z, Y_n 사이의 mutual information 은 다음과 같이 정의됩니다:  I(Z; Y_n) = H(Z) - H(Z|Y_n) . 즉, 두 loss function의 joint learning을 통해  Z 는  Y_n 과 서로 mutually informative한 space로 한정되게 됩니다.

제안하는 방식을 영상 분류 문제에 적용하기 위해 ResNet 구조에맞게 구조를 변경하였습니다. ResNet의 top layers는 아래 그림과 같이 average-pooling layer와 fully-connected layer의 조합으로 구성됩니다:  y_{1d} = g_{\theta_{fc}} (avgpool (z_{3d})) (아래 그림의 왼쪽). 여기서  g 와  avgpool 은 서로 commutative하기 때문에 ( avgpool is a linear operation),  y_{1d} = avgpool (g_{\theta_{conv_{1\times1}}} (z_{3d})) 와 equivalent 합니다 (아래 그림의 오른쪽). 실제 실험에서는 approximate inverse function  h 를 좀 더 정확하게 modeling하기 위해, original ResNet 구조(왼쪽)가 아닌 변형된 ResNet 구조(오른쪽)를 사용합니다.

proposed_modification.png

변형된 ResNet 구조에 기반한 제안하는 방식은 아래와 같습니다 (왼쪽: 1st stage, 오른쪽: following stages). 즉, 1st stage에서  \theta_s, \theta_n, \theta_r 을 학습하고, 이 후 stages에서는  \theta_s, \theta_n 을 학습함과 동시에  \theta_o, \theta_r, Y_o 를 고정하여 1st stage에서 modeling한 feature space  Z 를 복원할 수 있도록 합니다. 이와같이 following stages 학습에서 feature space  Z 를 1st stage에서 학습된 space로 계속 한정하게 되면, feature extractor  f 가 새로운 data examples를 과거 data를 기억하고 있는 modeled space로 끌어당기는 효과를 기대할 수 있습니다.

proposed_real.png


Experiments

CIFAR-10/100 및 Chest X-rays [6] dataset을 통해 제안하는 방식의 feasibility를 확인했습니다. CIFAR-10/100는 50k의 training examples와 10k의 test examples로 구성되어있습니다. 실험을 위해, 50k의 training examples를 4 sets of 10k examples로 나누어 4 stages learning을 수행했고, 남은 10k examples를 validation set으로 하여 model selection을 수행하고, 10k test examples에 대한 classification error를 최종 성능평가 기준으로 하였습니다. 아래의 표는 각각의 방법론에 대한 5회 실험의 mean(std)를 정리한 결과입니다. 결과에서 확인할 수 있듯이, multi-center single-task learning에 적합하도록 수정된 LwF+/EWCLwF+ 방식이 original LwF/EWCLwF 방식 대비 약간 더 좋은 성능을 보이는 것을 확인할 수 있으며, stage-2 부터는 제안하는 방식이 다른 방식들 대비 좋은 성능을 보이는 것을 확인할 수 있습니다.

CIFAR.png

Stage-1 학습이 완료되면, stage-1의 training data는 다음 stages에서 활용되지 않습니다. 따라서 stage-4를 통해 학습된 최종 model로 stage-1의 training data에 대한 classification accuracy를 확인해 봄으로써 각 방식이 knowledge preservation 관점에서 얼마나 좋은 성능을 보이는지 확인할 수 있었고, 그 결과 아래와 같이 제안하는 방식이 knowledge preservation 관점에서 가장 좋은 성능을 보이는 것을 확인할 수 있었습니다.

exp_knowledge_preservation.png

두 번째 실험에 사용된 Chest X-ray (CXR) dataset [6]은 결핵 진단을 위해 수집된 dataset으로 3,556 abnormal, 6,952 normal 영상으로 구성되어있습니다. 전체 dataset을 20%  \times 4 , 10%, 10%으로 나누어 4 training stages, validation, test 에 활용했습니다. 5번의 실험에 대한 AUC (area under the ROC curve)의 mean(std)로 성능을 평가 했고 결과는 아래와 같습니다. CIFAR-10/100 실험결과와 비슷하게 제안하는 방식이 가장 좋은 성능을 보이는 것을 확인할 수 있습니다.

CXR.png

마찬가지로, knowledge preservation 관점에서 각 방식의 성능을 확인해보기 위해, stage-4 model로 stage-1 training data에 대한 AUC (및 ROC curve)를 확인해 본 결과, 아래와 같이 제안하는 방식이 가장 좋은 성능을 보이는 것을 확인할 수 있었습니다.

ROC.png


Conclusion

Deep learning에서 catastrophic forgetting 문제는 중요한 연구주제로 다뤄져 왔습니다. 특히, 병원/의료데이터 환경에서는 반드시 해결해야하는 문제 중 하나로 생각될 수 있습니다. 본 연구에서 제안하는 방법은 제한된 multi-center single-task learning 환경에서 catastrophic forgetting을 극복하는데 사용할 수 있는 방법이며, 좀 더 다양한 dataset 및 task 에서의 검증을 통해 practically 사용될 수 있는 방법이 될 수 있을거라 기대하고 있습니다.


Reference

[1] McCloskey, M. and Cohen, N.J., Catastrophic interference in connectionist networks: The sequential learning problem, Psychology of learning and motivation 24, 109-165 (1989)

[2] Girshick, R., Donahue, J., et al., Rich feature hierarchies for accurate object detection and semantic segmentation, In: CVPR (2014)

[3] Kirkpatrick, J., Pascanu, R., et al., Overcoming catastrophic forgetting in neural networks, In: PNAS (2017)

[4] Li, Z. and Hoiem, D., Learning without forgetting, In: ECCV (2016)

[5] Vincent, P., Larochelle, H., et al., Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion, JMLR 11, 3371-3408 (2010)

[6] Hwang, S., Kim H.E., et al., A novel approach for tuberculosis screening based on deep convolutional neural networks, In: SPIE medical imaging (2016)

Posted by:Hyo-Eun Kim

research scientist @ Lunit

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Google+ photo

You are commenting using your Google+ account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s