[액션파워 LAB] An Empirical Study of Example Forgetting during Deep Neural Network Learning (paper review)

ActionPower
10 min readJan 30, 2023

--

출처 : shutterstock

2019 ICLR에서 혜성(?) 같이 등장한 논문으로, training dynamics를 이용하여 example difficulty를 측정하는 첫 논문입니다.

국내에서는 많이 소개가 된 것 같지 않아서 액션파워에서 간단하게 공유드려 보고자 합니다.

1. Introduction

Continual learning task에서 neural network가 catastrophic forgetting 현상을 겪는다는 것은 이미 잘 알려져 있습니다.

(catastrophic forgetting : neural network가 새로운 task를 학습할 때, 전에 배웠던 정보를 잊는 것 — task는 말 그대로 새로운 task가 될 수도 있고 새로운 dataset이 될 수도 있습니다.)

그렇다면 동일한 task 내에서도 이와 비슷한 현상이 일어날까요?

저자들은 이러한 궁금증을 해결하기 위해 each mini-batch를 mini-task로 생각하여 실험을 진행하였으며, 그 결과 (catastrophic) example forgetting이라는 현상을 관측하였습니다.

저자들이 example forgetting을 어떻게 정의하였는지, 실험과 분석 결과를 간단하게 소개해봅니다.

2. Defining and Computing Example Forgetting

Catastrophic forgetting이 일어났는지를 확인하기 위해서, 새로운 new task를 배우기 전/후의 기존 task 성능을 비교합니다.

Mini-batch를 task라고 생각한다면, 특정 mini-batch를 학습하면서 다른 data point들의 성능을 측정해볼 수 있겠습니다.

그러나 이를 엄밀하게 구한다면 iteration마다 해당 mini-batch를 제외한 모든 data point들에 대하여 성능을 측정하는 것은 시간이 많이 걸리게 됩니다. 이러한 문제로 인해 mini-batch를 학습하기 전에 mini-batch에 해당하는 data point들에 대해서만 성능을 측정하게 됩니다.

Train dataset은 shuffling이 되어서 매 epoch마다 mini-batch가 바뀝니다. 이로 인해 관측 대상은 each data point (1 example 혹은 1 data)가 됩니다.

저자들은 classification task에서 모든 실험을 진행하기에, metric은 accuracy가 됩니다. 관측 대상이 1개의 data point이기에, accuracy는 0 혹은 1이 됩니다. 이는 epoch마다 측정을 하게 되며, 각 data point마다 epoch 길이 만큼의 accuracy vector를 가지게 됩니다.

말만 들었을 때는 무슨 소리인지 이해하기 어려울 수 있는데, 이를 pseudo code로 나타내면 다음과 같습니다.

Data point의 accuracy vector을 이용하여 저자들은 forgetting event, learning event라는 2가지 metric을 제안하였습니다.

- Forgetting event : acc_t > acc_t+1
- Learning event : acc_t < acc_t+1

Forgetting event가 발생하는 경우는 t epoch에서의 accuracy가 1이고 t+1 epoch에서의 accuracy가 0일 때 입니다. 즉 t epoch에서는 data point를 알고 있었는데, 다른 mini-batch들을 학습하면서 t+1 epoch에서는 data point를 까먹은 것이죠.

이와 반대로 learning event가 발생하는 경우는 t epoch에서의 accuracy가 0이고 t+1 epoch에서의 accuracy가 1일 때 입니다. 즉 t epoch에서는 data point를 모르고 있었는데, 다른 mini-batch들을 학습하면서 t+1 epoch에서는 data point를 학습한 것입니다!

3. Experiments and Analysis

3.1. Forgetting Events on MNIST, permutedMNIST, CIFAR-10

저자들은 computer vision에서 자주 사용되는 소규모 dataset인 MNIST, permutedMNIST, CIFAR-10에 대하여 training dataset에 대하여 forgetting event를 측정하였습니다.

실험 결과, 어떤 data는 forgetting event가 거의 일어나지 않은 반면 어떤 data는 forgetting event가 매우 많이 일어난다는 것을 확인할 수 있습니다.

저자들은 도대체 어떤 data가 forgetting event가 안일어나는지, 그리고 어떤 data가 forgetting event이 많이 일어나는지 궁금하여 이에 해당하는 data를 찾은 결과는 다음과 같습니다.

Forgetting event가 일어나지 않는 data는 상대적으로 물체가 중앙에 위치해있으며, class label의 특징이 전반적으로 포함되고 있음을 확인할 수 있습니다.

그에 반해 forgetting event가 자주 일어나는 data는 다른 label에서 볼 수 있을만한 atypical class characteristic을 가짐을 확인할 수 있습니다.

(예시 : forgettable — cat을 보면 강아지인지 고양이인지… 전 아직도 잘 모르겠습니다;)

3.2. Detecting of Noisy Examples

Noisy-labeled data는 해당 label이 noisy하기에 atypical class characteristic을 가지기에, forgetting event를 이용하여 noisy example을 detect할 수 있는지 저자들은 실험합니다.

(noisy-labeled data : cat image with truck label)

실험 결과, noisy-labeled dataset을 학습하여 forgetting event를 측정한 결과, noisy-labeled data는 무조건 forgetting event가 발생하며 기존의 regular examples과 다른 분포를 가진다는 것을 확인할 수 있습니다.

이 결과를 통해 ‘forgetting event가 많이 발생한다 → noisy-labeled data‘라고 하기에는 무리가 있습니다만, forgetting event가 많은 data들을 확인하여 noisy-labeled data를 filtering하는 방법을 통해 filtering 시간을 단축시킬 수 있습니다.

3.3. Removing Unforgettable Examples

저자들은 실험을 통해 동일한 task 내에서 example forgetting 현상이 일어남을 확인하였으니, 원래의 intuition인 catastrophic forgetting이 발생하는 continual learning과 비슷하게 셋업하여 실험을 진행합니다.

  1. Random split : random하게 dataset을 2개의 partition으로 나누어 번갈아가면서 학습합니다
  2. Split by Forgetting Event : forgetting event가 일어나지 않은 경우와 일어난 경우로 나누어 번갈아가면서 학습합니다

두 경우 모두 catastrophic forgetting 현상은 일어나나, 2번이 1번과 다른 특성이 나타난다는 것을 확인할 수 있습니다.

  • Forgetting event가 일어나지 않은 dataset을 학습할 때, forgetting event가 일어난 dataset의 accuracy가 크게 감소합니다 (random에 비해)
  • Forgetting event가 일어난 dataset을 학습할 때, forgetting event가 일어나지 않은 dataset의 accuracy는 거의 감소하지 않습니다 (random에 비해)

즉 이를 통해 forgetting event가 일어나지 않은 data는 forgetting event가 일어난 data에 비해 상대적으로 information 양이 적다는 것을 유추할 수 있습니다.

(information의 양이 많다면, 새로운 task를 학습하면서 기존의 information을 많이 잊어먹을 것이니)

저자들은 forgetting event가 일어나지 않은 data의 information 양이 적다는 것을 실험을 통해 보입니다.

Figure 5의 left plot에 대하여 설명하면 다음과 같습니다.

  • x축 : training dataset을 remove하는 비율
  • y축 : test dataset accuracy
  • 빨간색 선 : train dataset을 remove하지 않고 모두 사용하였을 때의 test accuracy (upper bound)
  • 파란색 선 : train dataset을 random remove하였을 때의 test accuracy
  • 초록색 선 : train dataset을 forgetting event가 적게 일어난 순서대로 remove하였을 때의 test accuracy

초록색 선이 x = 20%가 될 때에도 test accuracy의 성능이 낮아지지 않는다는 것을 확인할 수 있습니다.

즉 이를 통해 forgetting event가 발생하지 않은 data는 학습에 제외해도 generalization performance에 영향을 주지 않으며, 이는 곧 information 양이 적다는 것을 보여준다고 할 수 있습니다.

즉 forgetting event가 일어나지 않은 data들은 제거해도 된다!라고 생각할 수 있겠습니다.

4. Conclusion

논문에는 이외에도 다양한 실험들을 진행합니다.

(예를 들면 metric의 엄밀성에 대한 실험 — stability across seeds, forgetting by chance — 기존의 연구들과의 관련성 — support vectors -, 어떻게 하면 더 효율적으로 forgetting event를 구할 수 있을지 등…)

하지만 이를 다 다루기에는 너무 글이 길어지게 되어… 관심 있는 분들은 논문 링크를 달아둘테니, 읽어보시고 모르는 부분은 댓글을 남겨주시면 답변을 달도록 하겠습니다!

해당 논문이 나온 이후로 example difficulty (혹은 example importance)를 어떻게 하면 더 빨리 잘 측정할 수 있는지, dataset pruning, noisy label detection 등 다양한 방식으로 발전하고 있습니다.

기회가 된다면 다른 논문으로 또 찾아뵙도록 하겠습니다!

읽어주셔서 감사합니다.

Reference

[1] 2019 ICLR, An Empirical Study of Example Forgetting during Deep Neural Network Learning, https://arxiv.org/abs/1812.05159

--

--

ActionPower

Cutting-edge AI for the Benefit of the World. Unlock the potential of AI for a better tomorrow! For more details: actionpower.kr/en