멀티헤드 어텐션

하나의 어텐션 헤드(attention head)는 한 번에 하나의 관계를 학습합니다. 여덟 개의 헤드는 여덟 개를 학습합니다. 헤드는 싸게 더할 수 있습니다. 더 많이 사용합니다.

유형: Build 언어: Python 선수 지식: Phase 7 · 02 (Self-Attention from Scratch) 예상 시간: 약 75분

문제

하나의 셀프 어텐션 헤드(self-attention head)는 하나의 어텐션 행렬(attention matrix)을 계산합니다. 그 행렬은 한 종류의 관계를 포착합니다. 보통 훈련 신호(training signal)에서 손실(loss)을 가장 줄이는 관계입니다. 데이터(data) 안에 주어-동사 일치(subject-verb agreement), 공지시(co-reference), 장거리 담화(long-range discourse), 구문 청킹(syntactic chunking)이 모두 얽혀 있다면 단일 헤드(single head)는 이것들을 하나의 소프트맥스 분포(softmax distribution)로 섞어 버리고 많은 신호(signal)를 잃습니다.

2017년 Vaswani 논문의 해결책은 여러 어텐션 함수(attention function)를 병렬로 실행하는 것이었습니다. 각 함수는 자기만의 Q, K, V 사영(projection)을 가지며, 출력을 연결(concatenate)합니다. 각 헤드는 d_model / n_heads 차원(dimension)의 더 작은 부분공간(subspace)에서 작동합니다. 전체 파라미터(total parameter) 수는 거의 유지되고 표현력(expressive power)은 올라갑니다.

멀티헤드 어텐션(Multi-head attention)은 2026년 기준 모든 트랜스포머(transformer)의 기본값입니다. 논쟁은 헤드를 몇 개 쓸지, 키(key)와 값(value) 사영을 공유할지, 즉 그룹드 쿼리 어텐션(Grouped-Query Attention, GQA), 멀티쿼리 어텐션(Multi-Query Attention, MQA), 멀티헤드 잠재 어텐션(Multi-head Latent Attention, MLA) 중 무엇을 택할지에 있습니다.

사전 테스트

2문제 · 이 강의를 시작하기 전에 얼마나 알고 있는지 확인해보세요

1.단일 어텐션 헤드(single attention head)가 여러 유형의 관계(예: 구문과 공지시)를 동시에 포착하기 어려운 이유는 무엇인가요?

2.d_model=512이고 헤드(head) 8개인 멀티헤드 어텐션(multi-head attention)에서 각 헤드 내 Q, K, V의 차원은 얼마인가요?

0/2 답변 완료

개념

multi-head attention — split, attend in parallel, concatenate X (N, d_model) @ W_q @ W_k @ W_v split into 4 heads Q (N, d_model) → reshape (4, N, d_head) Q_1 Q_2 Q_3 Q_4 K_1 K_2 K_3 K_4 V_1 V_2 V_3 V_4 softmax(Q_h K_h^T / √d_head) V_h — batched matmul, one op head 1 subject-verb head 2 positional head 3 copy / induction head 4 named entities concat → (N, d_model) @ W_o (d_model, d_model) output (N, d_model) W_o is where the heads mix. until W_o, every head lives in its own subspace.

분할(Split). 형상(shape)이 (N, d_model)X를 받습니다. Q, K, V를 각각 (N, d_model)로 사영(project)합니다. 그런 다음 d_head = d_model / n_heads일 때 (N, n_heads, d_head)로 재구성(reshape)하고 (n_heads, N, d_head)로 전치(transpose)합니다.

병렬로 참조(Attend in parallel). 각 헤드 안에서 스케일드 점곱 어텐션(scaled dot-product attention)을 실행합니다. 각 헤드는 (N, d_head)를 만듭니다. 헤드들은 임베딩(embedding)의 서로 다른 부분공간에서 작동하며 어텐션 계산 중에는 서로 대화하지 않습니다.

연결하고 사영(Concatenate and project). 헤드를 다시 (N, d_model)로 쌓고, 학습된 출력 행렬(learned output matrix) W_o (d_model, d_model)을 곱합니다. W_o는 헤드들이 섞이는 지점입니다.

왜 작동하는가. 각 헤드는 표현 예산(representational budget)을 두고 다른 헤드와 경쟁하지 않고 전문화(specialize)할 수 있습니다. 2019-2024년 탐침 연구(probing study)는 구분되는 헤드 역할(distinct head role)을 보여줍니다. 위치 헤드(positional head), 이전 토큰(previous token)에 참조하는 헤드, 복사 헤드(copy head), 개체명 헤드(named-entity head), 유도 헤드(induction head)처럼 역할이 나뉩니다. 유도 헤드는 문맥 내 학습(in-context learning)의 중요한 회로로 알려져 있습니다.

2026년 기준 변형 계보(variant lineage):

변형(Variant)Q 헤드K/V 헤드사용 모델
Multi-head (MHA)NNGPT-2, BERT, T5
Multi-query (MQA)N1PaLM, Falcon
Grouped-query (GQA)NG (예: N/8)Llama 2 70B, Llama 3+, Qwen 2+, Mistral
Multi-head latent (MLA)N저랭크(low-rank)로 압축DeepSeek-V2, V3

GQA는 N/G만큼 KV 캐시 메모리(KV-cache memory)를 줄이면서 품질을 거의 유지하기 때문에 현대적 기본값(modern default)입니다. MLA는 K/V를 잠재 공간(latent space)으로 압축(compress)하고 계산 시점(compute time)에 다시 사영해서 더 많은 메모리를 절약합니다. 대신 부동소수점 연산량(FLOPs)이 더 듭니다.

직접 만들기

Step 1: 단일 헤드 어텐션(single-head attention)에서 헤드 나누기

Lesson 02의 SelfAttention을 가져와 분할/연결 쌍(split/concat pair)으로 감쌉니다. 원문은 NumPy 구현을 언급하지만, 현재 저장소의 code/main.py는 표준 라이브러리만으로 같은 논리를 구현합니다. 논리는 다음과 같습니다.

def split_heads(X, n_heads):
    n, d = X.shape
    d_head = d // n_heads
    return X.reshape(n, n_heads, d_head).transpose(1, 0, 2)  # (heads, n, d_head)

def combine_heads(H):
    h, n, d_head = H.shape
    return H.transpose(1, 0, 2).reshape(n, h * d_head)

재구성 한 번과 전치 한 번입니다. PyTorch의 nn.MultiheadAttention도 내부적으로 같은 아이디어를 사용합니다.

Step 2: 헤드별 스케일드 점곱 어텐션 실행

각 헤드는 Q, K, V의 자기 슬라이스(slice)를 받습니다. 어텐션은 배치 행렬 곱(batched matmul)이 됩니다.

def mha_forward(X, W_q, W_k, W_v, W_o, n_heads):
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    Qh = split_heads(Q, n_heads)         # (heads, n, d_head)
    Kh = split_heads(K, n_heads)
    Vh = split_heads(V, n_heads)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(Qh.shape[-1])
    weights = softmax(scores, axis=-1)
    out = weights @ Vh                    # (heads, n, d_head)
    concat = combine_heads(out)
    return concat @ W_o, weights

실제 하드웨어(hardware)에서는 Qh @ Kh.transpose(...)가 하나의 bmm입니다. GPU는 (heads, N, d_head) × (heads, d_head, N) -> (heads, N, N) 형태의 배치 행렬 곱 하나를 봅니다. 헤드를 추가하는 비용이 상대적으로 낮은 이유입니다.

Step 3: 그룹드 쿼리 어텐션(Grouped-Query Attention) 변형

키와 값 사영만 바뀝니다. Q는 n_heads개 그룹(group)을 갖고, K와 V는 n_kv_heads < n_heads개 그룹만 만든 뒤 반복해서 맞춥니다.

def gqa_project(X, W, n_kv_heads, n_heads):
    kv = split_heads(X @ W, n_kv_heads)       # (kv_heads, n, d_head)
    repeat = n_heads // n_kv_heads
    return np.repeat(kv, repeat, axis=0)      # (n_heads, n, d_head)

추론(inference)에서는 KV 캐시에 n_heads개 복사본(copy)이 아니라 n_kv_heads개 복사본만 저장하므로 메모리를 아낍니다. Llama 3 70B는 64개 쿼리 헤드와 8개 KV 헤드를 사용해 캐시를 8배 줄입니다.

Step 4: 각 헤드가 무엇을 배웠는지 탐침(probe)하기

짧은 문장에 4-head MHA를 실행합니다. 각 헤드별 (N, N) 어텐션 행렬을 출력합니다. 무작위 초기화(random initialization)에서도 헤드마다 다른 구조를 볼 수 있습니다. 일부는 신호이고, 일부는 부분공간의 회전 대칭(rotational symmetry)에서 나온 효과입니다.

사용해보기

PyTorch에서는 한 줄입니다.

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

PyTorch 2.5+ 기준 GQA는 다음처럼 사용합니다.

from torch.nn.functional import scaled_dot_product_attention

# scaled_dot_product_attention은 CUDA에서 Flash Attention으로 자동 디스패치(auto-dispatch)된다.
# GQA에서는 Q의 shape가 (B, n_heads, N, d_head), K,V의 shape가
# (B, n_kv_heads, N, d_head)가 되도록 넣는다. PyTorch가 반복(repeat)을 처리한다.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

헤드는 몇 개가 적절한가? 2026년 프로덕션 모델(production model) 기준 경험칙(rule of thumb)은 다음과 같습니다.

모델 크기(Model size)d_modeln_headsd_head
Small (~125M)7681264
Base (~350M)10241664
Large (~1B)204816128
Frontier (~70B)819264128

d_head는 거의 항상 64 또는 128에 모입니다. 한 헤드가 얼마나 볼 수 있는지의 단위입니다. 32 아래로 내려가면 sqrt(d_head) 스케일링 계수(scaling factor)와 헤드 용량(head capacity)이 부딪히기 시작하고, 256 위로 올라가면 "작은 전문가들이 많이 모여 있는(many small specialists)" 구조의 장점이 약해집니다.

산출물 만들기

outputs/skill-mha-configurator.md를 봅니다. 이 스킬(skill)은 파라미터 예산(parameter budget), 시퀀스 길이(sequence length), 배포 대상(deployment target)을 받아 새 트랜스포머의 헤드 수(head count), KV-헤드 수(KV-head count), 사영 전략(projection strategy)을 추천합니다.

연습문제

  1. 쉬움. code/main.py의 MHA에서 d_model=64를 고정하고 n_heads를 1에서 16으로 바꿉니다. 합성 복사 과제(synthetic copy task)에서 작은 1계층 모델(tiny one-layer model)의 손실을 그립니다(plot). 헤드가 늘면 도움이 됩니까, 정체(plateau)가 옵니까, 아니면 나빠집니까?
  2. 중간. MQA를 구현합니다. 모든 쿼리 헤드가 하나의 KV 헤드를 공유합니다. 전체 MHA 대비 파라미터 수(parameter count)가 얼마나 줄어드는지 측정하고, N=2048 추론에서 KV 캐시 크기가 얼마나 줄어드는지 계산합니다.
  3. 어려움. 멀티헤드 잠재 어텐션(Multi-head Latent Attention)의 작은 버전(tiny version)을 구현합니다. K,V를 rank-r 잠재 표현(latent)으로 압축하고, KV 캐시에는 잠재 표현을 저장한 뒤 어텐션 시점(attention time)에 복원(decompress)합니다. 어떤 r에서 캐시 메모리가 전체 MHA의 1/8 아래로 내려가면서 검증 ppl(validation ppl)이 1비트(bit) 이내로 유지됩니까?

핵심 용어

용어흔한 설명실제 의미
헤드(Head)"하나의 어텐션 회로(attention circuit)"자기 어텐션 행렬을 가진 d_head = d_model / n_heads 차원의 Q/K/V 사영 하나다.
d_head"헤드 차원(Head dimension)"헤드별 은닉 폭(hidden width)이다. 프로덕션에서는 거의 항상 64 또는 128이다.
분할/결합(Split / combine)"재구성 트릭(Reshape trick)"어텐션 전후에 (N, d_model) ↔ (n_heads, N, d_head)로 reshape+transpose하는 과정이다.
W_o"출력 사영(Output projection)"헤드 연결 뒤 적용하는 (d_model, d_model) 행렬이다. 헤드들이 섞이는 곳이다.
MQA"KV 헤드 하나"멀티쿼리 어텐션(Multi-Query Attention)이다. 하나의 공유 K/V 사영(shared K/V projection)을 사용해 KV 캐시는 가장 작지만 품질 손실이 있을 수 있다.
GQA"Llama 2 이후 기본값(default)"n_kv_heads < n_heads인 그룹드 쿼리 어텐션(Grouped-Query Attention)이다. K/V를 반복해 Q 헤드 수에 맞춘다.
MLA"DeepSeek의 기법"멀티헤드 잠재 어텐션(Multi-head Latent Attention)이다. K,V를 낮은 랭크 잠재 표현(low-rank latent)으로 압축하고 어텐션 시점에 복원한다.
유도 헤드(Induction head)"문맥 내 학습 뒤의 회로"이전 출현(occurrence)을 감지하고 그 뒤에 온 토큰을 복사하는 헤드 쌍(head pair)이다.

더 읽을거리

실습 코드

이 강의의 실습 코드 1개

main
Code

산출물

이 강의에서 생성된 프롬프트, 스킬, 코드 산출물 1개

mha-configurator

Recommend head count, KV-head count, and projection strategy (MHA / MQA / GQA / MLA) for a new transformer.

Skill

확인 문제

3문제 · 모두 맞추면 완료 표시가 가능합니다

1.멀티헤드 어텐션(multi-head attention) 파이프라인에서 개별 헤드들이 서로 정보를 교환하는 시점은 언제인가요?

2.Llama 3 70B는 64개의 쿼리(Query) 헤드와 8개의 KV 헤드만 사용합니다(GQA). 이 설계의 주요 이점은 무엇인가요?

3.학습된 모델에서 한 어텐션 헤드가 현재 토큰의 이전 출현 다음에 나온 토큰을 일관되게 복사하는 것이 관찰됩니다. 이 헤드 역할의 이름은 무엇이며 왜 중요한가요?

0/3 답변 완료

추가 문제 풀기

AI가 강의 내용을 바탕으로 새로운 문제를 생성합니다