GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai
2023

Abstract

Multi-query attention (MQA), which uses only a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference. We (1) propose a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of the original pre-training compute, and (2) introduce grouped-query attention (GQA), a generalization of multi-query attention that uses an intermediate number of key-value heads (more than one, fewer than the number of query heads). We show that uptrained GQA achieves quality close to multi-head attention with speed comparable to MQA.
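To make the idea concrete, the sketch below implements grouped-query attention by sharing each key-value head across a group of query heads. It is a minimal illustration, not the paper's code; the function name, shapes, and head counts are chosen for the example.

```python
import torch


def grouped_query_attention(q, k, v):
    """q: (batch, num_q_heads, seq, head_dim); k, v: (batch, num_kv_heads, seq, head_dim)."""
    num_q_heads, num_kv_heads = q.shape[1], k.shape[1]
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of KV heads"
    group_size = num_q_heads // num_kv_heads

    # Repeat each key/value head so every query head in a group attends to the same K/V.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)

    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


# num_kv_heads == num_q_heads recovers multi-head attention; num_kv_heads == 1 recovers MQA.
q = torch.randn(2, 8, 16, 64)   # 8 query heads
k = torch.randn(2, 2, 16, 64)   # 2 key-value heads -> groups of 4 query heads
v = torch.randn(2, 2, 16, 64)
out = grouped_query_attention(q, k, v)  # shape (2, 8, 16, 64)
```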

Code References

3 repositories, 11 references

huggingface/transformers
4 files
src/transformers/models/longcat_flash/configuration_longcat_flash.py (1 reference)
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
src/transformers/models/qwen3_next/configuration_qwen3_next.py (1 reference)
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py (1 reference)
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py (1 reference)
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
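The matched docstrings appear to belong to the `num_key_value_heads` parameter of these Transformers configurations: setting it below the number of attention heads enables GQA, setting it equal gives standard multi-head attention, and setting it to 1 gives MQA. A hedged illustration using `LlamaConfig`, which exposes the same parameter (values here are arbitrary, not real model defaults):

```python
from transformers import LlamaConfig

# Illustrative config only; real models use different sizes and defaults.
config = LlamaConfig(
    hidden_size=1024,
    num_attention_heads=16,   # query heads
    num_key_value_heads=4,    # GQA: 4 KV heads shared across groups of 4 query heads
)
# num_key_value_heads == num_attention_heads -> multi-head attention
# num_key_value_heads == 1                   -> multi-query attention
```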
onnx/onnx
4 files
docs/Changelog.md (2 references)
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
docs/Operators.md (1 reference)
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
onnx/defs/nn/defs.cc (1 reference)
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
onnx/defs/nn/old.cc (1 reference)
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
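The constraint quoted in these ONNX references (`q_num_heads % kv_num_heads == 0`) means each key-value head serves a fixed-size group of query heads. A small plain-Python sketch of that mapping (the helper name is invented for illustration, not ONNX code):

```python
def kv_head_for_query_head(q_head: int, q_num_heads: int, kv_num_heads: int) -> int:
    """Return the key-value head index that a given query head attends with under GQA."""
    assert q_num_heads % kv_num_heads == 0, "q_num_heads must be divisible by kv_num_heads"
    group_size = q_num_heads // kv_num_heads
    return q_head // group_size


# With 8 query heads and 2 KV heads, query heads 0-3 share KV head 0 and 4-7 share KV head 1.
assert [kv_head_for_query_head(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```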
pytorch/pytorch
2 files
torch/nn/functional.py (1 reference)
https://arxiv.org/pdf/2305.13245
torch/onnx/ops/__init__.py (1 reference)
2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
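The `torch/nn/functional.py` reference is likely in the documentation of `scaled_dot_product_attention`, which can broadcast key-value heads across query-head groups when `enable_gqa=True`. A minimal sketch, assuming a recent PyTorch release where that flag is available:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)  # 8 query heads
k = torch.randn(1, 2, 128, 64)  # 2 key-value heads
v = torch.randn(1, 2, 128, 64)

# enable_gqa shares each KV head across 8 // 2 = 4 query heads without
# manually repeating the K/V tensors.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```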