본 포스트는 최근 발표된 새로운 normalization 기법인 Weight Standardization에 대해 소개합니다.

Introduction

Normalization은 머신러닝에서 데이터의 불필요한 정보를 제거하고 학습을 용이하게 하기 위해 매우 중요한 요소입니다. 특히 Batch Normalization (BN)은 딥러닝 모델의 학습을 안정시키고 가속화함으로써 성능을 크게 향상시켰고, ResNet을 비롯한 대부분의 state-of-the-art 모델에서 필수적인 요소로 사용되고 있습니다. 그러나 BN은 minibatch 단위로 normalization을 수행하기 때문에 모델의 성능이 large batch size에 의존하게 된다는 단점이 있습니다. 이러한 문제를 해결하기 위해 Group Normalization (GN) 등 다양한 기법이 제안되었지만, 일반적인 large-batch training 상황에서는 BN의 성능에 미치지 못하여 활용도가 현저히 떨어집니다.

이 논문에서 제안하는 Weight Standardization (WS) 은 GN과 같이 minibatch dependency를 완전히 제거하면서 large-batch 상황에서의 BN보다도 좋은 성능을 달성합니다 (그림 1).

스크린샷 2019-05-28 오후 1.31.14
그림 1. BN, GN, BN+WS의 성능 비교.

Background

Weight Standardization의 motivation을 이해하기 위해서는 먼저 BN에 대한 이해가 필요합니다. BN의 original paper를 포함한 기존의 연구들은 BN이 잘 되는 이유가 deep neural network 학습 시의 internal covariate shift (ICS)를 줄여주기 때문이라고 여겨왔습니다. 그러나 작년 NIPS에서 발표된 How Does Batch Normalization Help Optimization? 이라는 논문은 이러한 기존의 통념을 완전히 반박하는 결과를 제시하였습니다. 논문의 핵심을 요약하면, BN은 ICS와는 거의 관련이 없고, 대신 deep neural network의 optimization (loss & gradient) landscape를 훨씬 smooth하게 만들어 줌으로써 학습을 용이하게 한다는 것입니다 (그림 2).

landscape
그림 2. Visualization of smoothing effect on an optimization landscape.

이렇게 optimization landscape가 smooth해지면 다음과 같이 다양한 이점이 발생합니다.

  • Vanishing/exploding gradient를 예방하여 학습이 안정적으로 이루어집니다.
  • Gradient의 추정값이 reliable/predictive 해져서 global optimum을 더 잘 찾아갈 수 있습니다.
  • Learning rate 및 initialization과 같은 hyper-parameter에 robust 해집니다. 따라서 higher learning rate를 사용하여 학습을 빠르게 할 수 있습니다.
  • Sharp minima로 수렴할 위험이 줄어들어 generalization 성능에도 도움을 줍니다. 이에 관해서는 이전 포스트에서도 한번 다뤄진 적이 있습니다.

Weight Standardization은 BN의 이러한 효과에 착안하여, minibatch dependency 없이도 optimization landscape를 더욱 smoothing 할 수 있도록 고안되었습니다.

Method

Weight Standardization (WS)의 아이디어는 매우 간단합니다. Batch/Layer/Instance/Group Normalization 과 같은 기존의 기법들은 주로 feature activiation을 대상으로 normalization을 수행하는 반면, WS는 weight (convolution filter)을 대상으로 normalization을 수행합니다 (그림 3).

스크린샷 2019-05-28 오후 4.28.53.png
그림 3. Comparing normalization methods on activations and Weight Standardization.

즉, WS는 각 convolutional filter의 mean을 0, variance를 1로 조정해주게 됩니다. 이를 수식으로 표현하면 다음과 같습니다. Original filter weights를 W \in \mathbb{R}^{O \times I} (O: number of output channels, I: number of input channels \times kernel size)라 할 때, nomalize 된 filter weights \hat{W} \in \mathbb{R}^{O \times I}는 다음과 같이 계산되고, 이를 이용하여 최종적인 convolution을 수행하게 됩니다.

스크린샷 2019-05-28 오후 4.37.10

Property

앞서 설명드렸듯이 WS는 loss와 gradient의 landscape를 smoothing 하는 효과를 통해 큰 성능 향상을 가져옵니다. 이러한 smoothness를 formulate하기 위해 Lipschitzness라는 개념이 도입됩니다.스크린샷 2019-05-28 오후 3.35.30

이 때 L을 Lipschitz constant라고 부르며, 이 값이 작을수록 function f가 smooth해짐을 의미합니다. 또한 Lipschitz constant는 아래와 같이 gradient의 크기에 의해 좌우됩니다.스크린샷 2019-05-28 오후 4.01.00.png

따라서 loss (L)의 landscape를 smooth하게 만들기 위해서는 gradient (\Delta L)를 줄여야 하며, gradient의 landscape를 smooth하게 만들기 위해서는 gradient의 gradient (\Delta^2 L), 즉 Hessian (H)을 줄여야 합니다.

앞서 설명드린 How Does Batch Normalization Help Optimization? 논문은 BN이 activation에 대한 loss의 gradient (\Delta_x L) 및 Hessian (\Delta_x^2 L)을 줄이게 됨을 유도함으로써, BN이 optimization landscape를 smoothing 하는 효과가 있음을 보였습니다. 비슷한 맥락에서, 이 논문은 아래와 같이 WS가 weight에 대한 loss의 gradient (\Delta_w L) 및 Hessian (\Delta_w^2 L)을 직접적으로 줄이게 됨을 증명합니다. (자세한 유도과정은 생략하겠습니다.)ws_proof또한 이렇게 weight을 대상으로 normalization을 수행하는 것은 몇 가지 추가적인 장점이 있습니다.

  • Minibatch dependency가 없으므로 batch size와 완전히 무관하게 동작합니다.
  • CNN에서 weight은 activation에 비해 훨씬 용량이 적으므로 memory&time-efficient 하며, 특히 inference 시에는 weight이 fix되므로 추가 computation이 전혀 없습니다.
  • BN, GN 등 activation을 대상으로 한 normalization과 동시에 사용하여 성능을 더욱 향상시킬 수 있습니다.

Experiments

1. ImageNet Classification

아래 그림은 WS을 ImageNet classification에 적용했을 때의 결과를 보여줍니다. batch size=1일때, WS을 GN과 함께 사용하면 bath size=1일 때의 GN 뿐 아니라 batch size가 클때의 BN보다도 성능이 향상됩니다. 특히 ResNext-101과 같이 복잡한 모델에 대해서는 GN 잘 동작하지 않는데, WS를 함께 사용하면 성능을 크게 향상시킬 수 있습니다. 스크린샷 2019-05-28 오후 6.15.34

2. Detection & Segmentation

일반적으로 large batch를 사용할 수 있는 image classification 문제와 달리, detection/segmentation 문제에서는 이미지의 높은 해상도 및 모델의 복잡도로 인해 small batch를 사용할 수 밖에 없는 상황이 발생합니다.  따라서 이 논문은 이러한 상황에서 WS가 더욱 큰 효과를 발휘할 수 있음을 보여줍니다. 아래 그림은 Mask/Faster R-CNN에 WS를 적용했을 때의 detection/segmentation 결과이며, GN+WS를 사용하면 GN만 사용했을 때보다 성능이 크게 향상됨을 확인할 수 있습니다. 스크린샷 2019-05-28 오후 6.46.40

Conclusion

Weight Standardization (WS)은 매우 간단한 방법이지만 BN이 가지고 있는 minibatch dependency 문제를 효과적으로 해결하였고, 다양한 task에서 성능을 크게 향상시켰습니다. 사실 기존에도 Weight Normalization과 같이 비슷한 아이디어가 있긴 했지만, WS는 BN의 효과에 대한 theoretical understanding (landscape smoothing effect)를 기반으로 성능과 활용도를 크게 향상시켰다는 것에 더욱 의의가 있는 것 같습니다. 실제 구현도 기존의 convolutional layer에 코드 2~3줄만 추가하면 될 정도로 간단하고 추가적인 학습 trick이나 hyper-parameter tuning도 필요하지 않기 때문에, 다양한 상황에 쉽게 적용이 가능할 것으로 보입니다.

Posted by:hyeonseob

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Google photo

You are commenting using your Google account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s