[Paper Review] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 논문 리뷰

업데이트:

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (EMNLP 2023, Google Research)

[0] Abstract

  • Multi-query attention(MQA)는 단일 key-value head를 사용해 디코더 추론 속도를 크게 향상시킴. 하지만 품질 저하가 발생할 수 있고 더 빠른 추론을 위해 별도 모델을 학습시키는 것은 바람직하지 않을 수 있음.
  • GQA
    • (1) 기존 multi-head language model의 checkpoint를 MQA를 사용하는 모델로 5% 정도의 pretraining 계산량만 사용해서 uptrain하는 방법을 제안함.
    • (2) grouped-query attention(GQA)을 도입함. 이는 multi-query attention의 일반화 버전으로 intermediate(1개 초과, query head 수 미만)의 key-value head를 사용함.
    • (3) Uptrained GQA가 MQA와 비슷한 속도로 multi-head attention에 가까운 품질을 달성함을 보임.

즉, 기존의 MQA는 MHA보다 메모리를 절약할 수 있어서 좋았지만, 성능 안좋고 학습이 불안정하게 되었음. GQA는 MHA와 MQA 사이의 방법론. BUT

→ 품질은 MHA랑 비슷하고 → 속도는 MQA랑 비슷하다

[1] Introduction

  • autoregressive decoder inference은 디코더 가중치 및 모든 attention key와 value를 매 디코딩 스텝마다 로드해야 해서 transformer 모델의 bottleneck임.
  • Multi-query attention (MQA)으로 key와 value를 로드하는 대역폭을 크게 줄일 수 있음.
  • 그러나 MQA는 품질 저하와 학습 불안정성을 초래할 수 있고, 품질과 추론에 최적화된 별도 모델을 학습하는 것이 어려울 수 있음. 또한 많은 언어모델들이 MQA를 사용하지 않음.
  • 본 연구는 (1) MHA를 사용하는 체크포인트를 MQA를 사용하도록 작은 사전학습 계산량으로 uptrain하는 방법을 보이고, (2) MHA와 MQA의 중간인 GQA를 제안함.

[2] Method

2.1 Uptraining

  • Multi-head 모델을 multi-query 모델로 converting 하려면, (1) checkpoint를 변환한 후 (2) 새로운 structure에 adaptation 하도록 추가적인 pretraining을 해야함
    • (1) checkpoint 변환
      • 여러 head들의 key와 value projection 행렬을 mean poll 해서 single projection matrices로 변환
      • random 으로 key, value를 initialize하는 것보다 위의 방법이 더 좋음
    • (2) 변환된 체크포인트를 원래 pretraining 스텝의 일부(α)만큼 추가로 사전학습함.

2.2 Grouped-query attention

  • Query head를 G개 그룹으로 나누고 각 그룹이 단일 key head와 value head를 공유하는 방식
  • GQA-1은 MQA와 동일하고, GQA-H는 MHA와 동일함. 중간 수의 그룹을 사용하면 MQA보다 품질이 좋고 MHA보다 빠른 중간 모델이 됨.
  • GQA는 모델 크기가 클수록 더 효과적일 것으로 예상됨.

[3] Experiments

  • T5 Large, XXL 모델과 이를 GQA, MQA로 uptrain한 모델들을 다양한 요약, 번역, QA 데이터셋에서 평가함.
  • Uptrained MQA는 MHA-Large보다 품질이 높고 추론이 빠른 좋은 trade-off를 보임. GQA는 MQA와 비슷한 속도로 MHA-XXL에 준하는 성능을 보임.
  • Ablation으로 체크포인트 변환 방법, uptraining 비율, GQA 그룹 수의 영향을 분석함.

[4] Conclusion

  • Key-value 로딩에 의한 메모리 대역폭 병목을 multi-query attention이 품질 저하의 trade-off로 개선할 수 있음.
  • Multi-head 모델을 작은 pretraining 비용으로 multi-query로 전환하는 방법을 제안함.
  • Grouped-query attention을 도입해 MQA와 비슷한 속도로 MHA에 가까운 성능을 달성함.

댓글남기기