[Paper Review] Llama Series 비교 (1 vs 2 vs 3) 및 리뷰
업데이트:
LlaMA 1 vs 2 vs 3
model | LLaMA 1 | LLaMA 2 | LLaMA 3 |
---|---|---|---|
released | 2023.02 | 2023.07 | 2024.04 |
pretraining data | 1.4T | 2T | 15T |
context length | 2K (2048) | 4K (4096) | 8K (8192) |
model size | 7B, 13B, 33B, 65B | 7B, 13B, 70B, (34B) | 8B, 70B, (400B 출시예정) |
architecture | MHA | MHA (7B, 13B), GQA (34B, 70B) | GQA |
tokenizer | BPE SentencePiece tokenizer | BPE SentencePiece tokenizer (total vocab: 32K, LlaMA 1이랑 동일) | tokenizer (total vocab: 128K) |
[1] Llama 1
[2] Llama 2: Open Foundation and Fine-Tuned Chat Models
In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models outperform open-source chat models on most benchmarks we tested, and based on our human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. We provide a detailed description of our approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on our work and contribute to the responsible development of LLMs.
2.1 Llama 2: Pretraining
[1] Datasets
Llama1 보다 40% 더 많은 데이터로 학습, context length가 길어짐
- Data: LLaMA 1 (1.4T) → LLaMA 2 (2T tokens)
- Context Length: LLaMA 1 (2K) → LLaMA 2 (4K, 4096)
[2] Training Details
- Architecture
- decoder-only transformer (LLaMA1과 동일)
- 최근 생성 테스크에는 주로 decoder-only model을 사용
- +) encoder only model (ex. BERT): input 에 대한 분석과 이해가 필요한 task에 적합 (ex. sentence classification, named-entity recognition)
- +) encoder-decoder models (like origin transformer): seq to seq model과 같은 테스크에 유리 (ex. translation, summerization)
- transformer 구조를 보면 input embedding 뿐만 아니라 output embedding 역시 필요한데, 이 구조 자체가 seq2seq 모델을 학습하도록 design된 설계라 그럼
- +) decoder-only model (ex. GPT, LLaMA, Claude, PaLM)
- encoder 모델은 주로 input을 “잘” embedding 하기 위해 필요했음. 최근 LLM 모델들은 input sequnce를 잘 처리하여 다음에 올 output sequence로 변환하는데(seq2seq)에 focus를 두지 않고, 주어진 text 다음에 올 단어를 예측하는 것에 focus를 두고 있기 때문에 decoding only model을 채택하고 있음 (auto-regressive하게 문장을 생성하는 형태)
- Pre-normalization (GPT3)
- SwiGLU activation function (PaLM)
- Rotary Embeddings (GPTNeo)
- decoder-only transformer (LLaMA1과 동일)
- LLaMA 1과 2의 차이는?
- 대체로 비슷한데, 34B, 70B 의 큰 모델에 대해서 Group Query Attention (GQA) 적용
- size가 큰 모델에 대해서는 GQA 를 도입하여 Inference 효율화
- GQA 논문 리뷰는 다음을 참고하세요
- LLaMA1과 동일하게 BPE sentencepiece tokenizer 사용 (total vocab: 32K tokens)
2.2 Llama2-Chat: FineTuning
- [1] pretraining data를 self-supervised learning 을 통해 학습한 LLaMA2 모델을
- [2] SFT (supervised finetuning) 하여 LLaMA2-Chat 모델을 학습
- [3] 이후 RLHF 방법론으로 반복적으로 튜닝하여 quality 를 높임
[3] Llama 3
Performance
3.1 Pretraining
Model architecture
- tokenizer (32K → 128K)
- 8B, 70B 모두 GQA 도입
- 작은 모델에도 GQA를 도입하는게 효과적이었나?
- context length (4K → 8K)
Training Data
- data size: 2T → 15T tokens (WoW)
- code data: 4배 더 많음
- for multilingual use cases
- pretraining data의 5% 이상의 데이터는 non-english data로 구성되어있음 (30개 이상의 langs)
- data-filtering pipelines
- These pipelines include using heuristic filters, NSFW filters, semantic deduplication approaches, and text classifiers to predict data quality
- Llama2 모델을 text-quality classifiers로 사용했다고 함
- 다양한 소스의 데이터
- a data mix that ensures that Llama 3 performs well across use cases including trivia questions, STEM, coding, historical knowledge, etc.
pretraining
- scaling behavior에 대해 더 많은 발견을 했다고 함
- Llama 1 논문에서는
- 기존 training compute에 대한 chinchilla-optimal amount은 8B param model을 200B token으로 학습시키는 거였는데,
- Llama 1 저자들은 비슷한 크기의 모델을 더 많이 학습시켰더니 훨씬 성능이 좋아짐을 발견했었음
- Llama 3 에서는 8B, 70B 모델에 대해 15T tokens 으로 학습을 했는데, 그 이후에 학습시켰을 때에도 log-linearly 하게 모델의 성능이 좋아졌음을 발견했다고 함
- Llama 1 논문에서는
- 학습 효율화를 해서 95% 이상의 훈련시간을 단축했다 함
- 3가지 병렬화 진행: data parallelization, model parallelization, and pipeline parallelization.
- 16K GPU로 학습할 때, GPU 당 400 TFLOPS 이상의 compute utilization 을 달성한게 BEST 였다고 함
- 2개의 24K GPU cluster를 custom 으로 구축하여 학습했다고 함
- To maximize GPU uptime, we developed an advanced new training stack that automates error detection, handling, and maintenance. We also greatly improved our hardware reliability and detection mechanisms for silent data corruption, and we developed new scalable storage systems that reduce overheads of checkpointing and rollback.
- → 95% 이상의 훈련시간을 단축. Llama 2보다 3배 더 효율적으로 학습했다고 함.
3.2 Instruction fine-tuning
Our approach to post-training is a combination of supervised fine-tuning (SFT), rejection sampling, proximal policy optimization (PPO), and direct preference optimization (DPO)
- Supervised Fine-Tuning, SFT: labeling된 데이터를 활용해 모델을 fine-tuning
- Rejection Sampling: 모델이 생성한 결과 중 품질이 낮은 것들을 제거
- Proximal Policy Optimization, PPO: 강화학습 방법론
- Direct Preference Optimization, DPO: human preference에 따라 모델을 최적화하는 방법론
SFT에 사용되는 prompts와 PPO/DPO에 사용되는 preference rankings의 품질이 aligned 모델의 성능에 커다란 영향을 미침. → 데이터 선별을 열심히 했다 함
특히 PPO와 DPO를 통한 preference rankings 학습은 Llama 3의 추론 및 코딩 능력을 크게 향상시켰다고 함.
댓글남기기