오늘 소개해 드릴 논문은 ‘Improving BERT Fine-Tuning via Self-Ensemble and Self-Distillation’입니다.
콥스랩(COBS LAB)에서는 주요 논문 및 최신 논문을 지속적으로 소개해드리고 있습니다.
해당 내용은 유투브 ‘딥러닝 논문읽기 모임’ 중 ‘Improving BERT Fine-Tuning via Self-Ensemble and Self-Distillation’ 영상 스크립트를 편집한 내용으로, 영상으로도 확인하실 수 있습니다. (영상링크: https://youtu.be/JI5kXF_OUkY)
fine-tuning이라는 개념은 이전 포스팅 한 ‘BERT’와 ‘How transferable are features in deep neural networks?’라는 논문에서도 등장한 개념입니다. 본 논문을 통해서, fine-tuning의 개념을 더 자세히 알아보겠습니다.
오늘 소개해 드릴 논문인 ‘Improving BERT Fine-Tuning via Self-Ensemble and Self-Distillation’에서 배경 논문이 되는 ‘BERT’라는 논문은 자연어 처리 분야에서 큰 획을 그은 논문입니다. BERT의 Fine-Tuning Pre-trained 모델을 토대로 최근의 연구되고 있는 자연어처리 논문들은 대체로 이 Pre-trained 모델의 구조나 Pre-trained 부분을 중점적으로 다루고 있습니다. BERT와 같은 Pre-trained language 모델은 큰 규모의 unlabeled 데이터를 사용해서 input 데이터를 general 한 벡터 형태로 변형시킬 수 있도록 학습한 후에, labeled 데이터를 사용해서 각 문장이 어느 클래스에 속하는지 학습하는 Fine-Tuning 과정이 있습니다
BERT 모델을 향상하는 데에는 모델이 구조를 바꾸거나, feature extraction을 하거나, data augmentation 등의 다양한 방법이 있지만, Fine-Tuning 그 자체에 방법을 바꾸는 것에는 아직 완벽하게 연구되지 않고 있습니다. 따라서 이 논문에서는 외부 데이터나 추가적인 knowledge를 사용하지 않고 fine-tuning만을 허용하여 BERT의 활용을 극대화하려는 방법을 연구하였습니다.
Fine-Tuning은 보통 SGD를 이용해서 weight값을 업데이트한 방법으로 학습이 진행됩니다. 하지만 Fine-Tuning을 random seed나 데이터에 순서에 취약해서 이 weight값이 변하면, 그 학습도 변할 수 있기 때문에 이러한 단점을 극복하고자 Ensemble을 이용해서 overfitting이 되거나 generalization을 높여서, 이런 random seed나 training 순서에 대한 민감성을 줄이도록 학습했습니다.
하지만 Ensemble 모델은 여러 개의 모델을 학습하고 그 모델에 대한 결과를 합하는 과정이라서, 시간과 리소스에 대한 비용이 많이 들 수밖에 없는 단점이 있습니다. 이러한 단점을 타파하기 위해, 그다음으로 이 논문에서 연구된 방법은 Self-Ensemble 모델입니다. Self Ensemble 모델은 기존의 Ensemble 모델에서처럼 아웃풋을 합치는 모델이 아니라, 각 모델의 파라미터들을 합치는 방법으로 각 모델의 아웃풋을 계산할 필요 없이 비용이 많이 드는 문제를 해결할 수 있습니다.
하지만, 이 문제는 Ensemble 모델의 기존의 BERT 베이스 모델에 Fine-Tuning 해주는 과정과 큰 차이가 없기 때문에, 그다음으로 고안한 방법은 Self Distillation입니다. Self Distillation에서 각 Fine-Tuning의 step마다 parameter를 Voted 합니다. 이렇게 투표(voted)하여 선정된 모델들을 ‘teacher’ 모델이라고 부르고 기존 모델을 ‘student’ 모델이라고 불러서, 두 개의 loss 값을 합쳐서 나온 값을 loss라고 판단하고, 그 loss에 대한 weight를 업데이트해주는 방식으로 학습을 진행합니다. 이러한 방식으로 다음 student 모델이 더 robust 하고 더 정확한 결과를 예측할 수 있다고 합니다. 또 이렇게 각 epoch 또는 batch가 진행될수록 student 모델이 개선되고, 그런 student 모델을 다시 또 averaging 하는 방식을 통해 teacher 모델도 개선이 된다고 합니다.
관련 연구에 관해서 살펴보겠습니다. 이 논문에서 두 개의 진행되는 과정인 Pre-trained language 모델과 knowledge Distillation에 대한 더 자세한 설명을 해보겠습니다. Pre-trained 모델은 Pre-trained language 모델 중에 이 논문에 사용된 모델은 ‘BERT’ 모델로, 대량의 다양한 분야의 unlabeled 문장 데이터를 사용해서 학습된 Pre-trained 모델입니다. BERT 베이스 모델은 열두 개의 Transformer encoder, 그리고 BERT 라지 모델은 24개의 Transformer encoder layer로 학습된 모델입니다. 이 두 개의 모델은 동일하게 512개의 토큰을 입력값으로 받고, CLS와 SEP라는 특수한 토큰을 이 입력값에 추가로 삽입해줍니다. CLS는 예를 들면 각 입력값에 시작에 추가가 되는데, classification task 같은 경우에는 입력값의 종류가 하나 이기 때문에, 문장에 첫 시작 부분에 한번 들어가고, 문장 비교와 같은 테스크에서는 문장 1과 문장 2가 각각 시작할 때 CLS가 두 번 들어갑니다. SEP는 각 입력값에 대해서 문장이 여러 개로 이뤄져 있을 때 문장을 구분하는데 추가로 입력됩니다. 이렇게 학습된 Pre-trained 모델을 문장에 입력하면 각 문장은 벡터 형태의 output을 가지게 되는데, 이 벡터에 대한 다양한 테스크를 추가 학습을 하는 과정을 Fine-Tuning이라고 합니다. 이렇게 이 Fine-Tuning 과정은 fully connected layer이기도 한데, 이런 output에 weight값을 곱해주고, 그 결과를 softmax로 결과를 계산합니다. 이 논문에서는 이런 Fine-Tuning을 응용해서 리소스가 제한된 환경에서, teacher 모델이나 teacher 모델을 활용해서, student 모델에 knowledge Distillation 시키는 방법을 사용합니다.
기존의 knowledge Distillation에서는 teacher 모델이 아주 완벽하게 학습되어있고, 그 완벽한 학습이 된 모델을 student 모델에게 전이시키는 방법으로 학습이 되었습니다. 하지만, 이 논문에서는 각 step마다 student 모델의 Ensemble로 teacher 모델을 만들고, 그 teacher 모델과 student 모델을 합친 결과로 다시 또 student 모델을 학습하는 방식으로 사용합니다.
다음으로, 각 모델들이 어떻게 구성이 되어있는지, Methodology 부분을 살펴보겠습니다. 기본적인 BERT의 Fine-Tuning은 BERT의 Pre-trained 모델의 마지막 hidden layer의 output, 즉 벡터 형태로 구성된 입력이 fully connected layer를 거치고 SGD를 거쳐서 weight를 업데이트합니다. 그리고 마지막으로 softmax output을 통해서, 각 class에 대한 probability를 가지게 되고, 그중에서 probability가 가장 높은 class가 output으로 계산되는 방식입니다.
하지만, 기본적인 BERT 모델은 seed size 혹은 데이터에 순서에 민감하기 때문에, 해당 문제점을 줄이고자 Ensemble 모델을 적용했습니다. 이 Ensemble 모델의 방법으로, 이 논문에서는 크게 두 가지 방법을 제안했습니다. 그 방법은 Voted 방법과 Average 방법입니다. Voted BERT라고 부르는 Voted 방법은 다양한 시드 값을 적용한 BERT 모델을 Fine-Tuning 시키고, 이렇게 나온 output을 모두 더해서, 가장 높은 probability를 가진 class를 최종 output으로 선정하는 것입니다. 하지만 이 방법은 여러 개의 모델을 Fine-Tuning 해야 하기 때문에, 학습이 오래 걸린다는 단점이 있습니다.
이러한 단점을 줄이기 위해, 다음으로 사용하는 방법은 Average BERT라는 모델의 output을 모델의 output이 아닌 모델의 parameter를 전부 모아서 평균을 낸 후, 그 평균을 낸 파라미터들을 이용해서 최종 output을 계산하는 방법입니다. 하지만 이 방법 역시 parameter값들을 전부 계산하기 때문에, Voted 보다는 계산 양이 적지만 여전히 계산 양이 적지 않다는 단점이 있습니다.
이러한 단점을 극복하고자, 저자가 고안한 다음 방법은 Self Ensemble 모델로, 여러 개의 모델 대신 gradient step을 이용하는 것입니다. gradient accumulate step 이라고도 부르는 이 방법은 일반적인 학습이 mini batch를 통해서 gradient를 구한 후, 그 결과의 weight를 즉시 업데이트하는 방법입니다. step은 mini batch를 통해서 구한 gradient를 t번의 step동안 누적시킨 후, 그것을 한 번에 업데이트시키는 방법입니다. 그래서 그 t step을 반복하여, 여러 개의 파라미터가 생기면 그 파라미터들의 평균을 구하고, 그 평균을 업데이트하는 방식으로 여러 번의 업데이트 과정을 거칠 수 있습니다. 이 방식으로, 여러 개의 모델을 계산할 필요가 없기 때문에 계산 양이 적지만 더 정확한 모델을 얻을 수 있습니다.
하지만, Self Ensemble 모델은 기존 Fine-Tuning 방법과 큰 차이가 없기 때문에, 이 논문에서는 Fine-Tuning을 극대화하고 하는 방법을 찾았고, 그래서 Self Distillation 방법을 사용했습니다. Self Distillation은 두 개의 teacher 모델과 student 모델의 결과를 합쳐 줌으로써, 더욱 robust 하고 개선된 학습이 가능하다고 합니다. 여기서 student 모델은 기본적인 BERT 베이스와 같은 모델로 Fine-Tuning을 학습한 결과입니다. 그리고, teacher 모델은 gradient accumulate step을 이용해서, 각 step들의 파라미터들을 모아서, Average이면 평균이고 Voted이면 파라미터들의 합을 얻은 결과입니다.
SDA모델은 teacher 모델을 계산하는데, 파라미터들의 평균을 계산한 모델입니다. 구체적인 방법은, Fine-Tuning에서 계산한 output을 Y값과 비교한 후에 얻은 cross entropy를, Fine-Tuning의 각 step에서 얻은 파라미터들을 합쳐서 평균을 얻은 결과와 기본적인 Fine-Tuning의 output의 결과를 비교한 값의 MSE를 구합니다. 그리고, 그 값에 hypter parameter lambda 값을 곱해 준 후에, cross entropy값과 MSE값을 더해서 얻은 loss 값을 토대로 weight를 업데이트하는 방식입니다. 이 방법을 통해서 학습이 진행될수록 cross entropy값을 계산하는 student 모델도 더 나은 결과를 보입니다. 그리고 student 모델의 파라미터들을 합해서 Average를 얻는 teacher 모델도 더 나은 결과를 얻는다고 합니다.
Self Distillation Voted 모델은 파라미터들의 평균이 아닌 파라미터들을 계산했을 때 나오는 최댓값을 가진 probability들을 모읍니다. 그리고 그 베이스 모델과 output을 비교하면서, MSE를 계산하고, 마찬가지로 그 cross entropy값과 더해주어서 weight를 업데이트하는 과정을 거칩니다. Voted BERT가 Average BERT보다 효율이 낮은 이유는, Self Distillation Voted 모델도 매번 output의 probability를 계산하는 과정까지 거쳐야 하기 때문에 계산이 추가로 좀 더 필요하고, 그래서 학습효율이 SDA보다는 조금 떨어집니다.
다음으로, experiment 부분에 대해서 설명드리겠습니다. 논문에서 모델을 학습하는 데 사용한 데이터는 크게 두 가지 종류로, text classification과 NLI가 있습니다.
text classification은 말 그대로 문장을 분류해주는 데이터들입니다. text classification에 속하는 데이터들은 크게 다섯 개입니다.
- IMDB 영화 긍정/부정 감정 리뷰 텍스트
- AG’s 뉴스: 세계/스포츠/ 비즈니스/ 과학
- DBPedia라는 14개의 클래스를 가진 각각의 클래스가 겹치지 않는 위키피디아에서 가져온 텍스트
- Yelp Polarity라는 Yelp라는 식당 리뷰 데이터에서 좋음과 나쁨으로 평가된 텍스트
- Yelp Full이라는 별이 5 개로 평가된 식당 리뷰 텍스트
NLI는 문장의 세트가 두 개가 주어지고, 그 문장의 관계에 대한 데이터입니다.
- SNLI는 스탠퍼드에서 구축한 문장 페어 관계 데이터로, 앞에 문장과 뒤 문장이 반대되거나 같은 내용이거나 관계없음이 labeled 되어 있습니다.
- MNLI는 10개의 장르에서 가져온 문장 페어 관계 데이터로, 동일하게 문장이 반대의 내용이거나 같은 내용이거나 관계없음 등이 labeled 되어 있습니다.
이 논문의 hypterparameter는 BERT 베이스 모델을 학습하는 데 사용한 동일한 hypterparameter를 사용했습니다.
그리고 이렇게 추가로 사용한 hyperparameter들을 비교해 봤을 때, IMDB 데이터에 다양한 BERT 베이스 모델의 람다 값에 변화를 주면서 비교했습니다. 람다 값이 1 일 때 오류 값이 가장 낮음을 확인할 수 있었습니다. 그리고 teacher 사이즈 K에 대해서는 SDE, SDA 둘 다 데이터별로 각각 다른 K를 가질 때, 가장 낮은 오차 값을 가진다는 것을 확인했습니다.
이제 모델을 비교해보겠습니다. 모든 모델을 비교하는 동안 동일한 seed size를 사용하였을 때, 베이스 BERT 모델보다 SDE나 SDA모델이 훨씬 나은 성능을 가진다는 것을 확인할 수 있습니다. 다만 seed가 K값 보다 더 많은 영향을 가지고 있기 때문에, 그리고 각각의 모델은 본인들이 더 맞는 seed가 있기 때문에, 시드들을 동일하게 했을 때 K 값은 큰 영향을 받지 않는다는 것을 확인할 수 있습니다.
Self Distillation의 효과를 좀 더 추가로 비교하기 위해서, 각 학습의 epoch를 추가할 때마다 test error값이 어떻게 변하는지 비교해 보았습니다. 위 그림의 빨간색이 베이스 모델이고, 베이스 모델은 epoch 3 이후로 test error rate가 크게 줄어들지 않는 반면에, 나머지 Self Distillation 모델은 그 후로 추가로 줄어드는 것을 확인할 수 있습니다. 그리고 앞에 언급한 변화가 Self Distillation의 효과인지 확인하기 위해서 이 모델을 학습하는 데 사용한 cross entropy, student 모델에서 얻은 loss값, 그 student 모델을 Average 한 파라미터 값으로 얻은 MSE값을 비교했습니다. 그 결과 cross entropy 값은 시작도 낮지만 변화가 없는 것에 비해서, MSE값은 처음에는 높았지만 급격하게 낮아지는 것을 확인할 수 있습니다. 이 결과를 통해서, 기존의 모델보다 SVD를 사용하는 방법이 조금 더 loss 값을 낮춰주는데 영향을 미친다는 것을 확인할 수 있습니다.
위 표에서, 이 논문에 등장한 모든 논문들을 각각의 데이터의 결과를 모두 비교했을 때, 데이터의 종류에 따라서 차이가 있지만, 전체적으로 Self Distillation 모델들이 기존의 BERT 모델보다 더 나은 오차 오류 값을 가진다는 것을 확인할 수 있습니다.
그리고 지금까지 모든 대부분의 모델은 BERT 베이스 모델을 활용했지만, BERT 라지 모델도 활용해서 SDA의 효과를 확인했을 때, SDA값 역시 기본 BERT 라지 모델보다 더 낮은 오차 값을 가진다는 걸 확인할 수 있었습니다. 그리고 BERT가 라지 경우에, classification은 K 값이 높아질수록 더 낮은 오차 갖는다는 것을 알 수 있습니다. 또한 NLI 테스크에서, K가 1일 경우에 더 높은 성능을 가지는 것을 확인할 수 있었습니다.
마지막으로 정리하자면, 이 논문의 목표는 추가적인 데이터와 지식 없이 BERT 모델을 Fine-Tuning을 통해서 개선시키는 것입니다. 그 방법으로, Self Ensemble과 Self Distillation을 활용했습니다. 결과적으로는 Self Ensemble 모델은 BERT를 개선시킬 수 있지만, 효율성이 좋지는 않았습니다. Self Distillation 모델은 Fine-Tuning에서 BERT 모델을 이전보다 발전시킬 수 있었습니다.
추가로, 이 논문의 저자는 data augmentation hyperparameter들의 개선을 통해서, 모델을 향상할 가능성을 확인했다고 합니다.
댓글