그래디언트 체크포인팅(Gradient Checkpointing)과 활성값 재계산(Activation Recomputation)
역전파(Backpropagation; Backprop)는 모든 중간 활성값(intermediate activation)을 보관합니다. 70B 파라미터(parameter)와 128K 컨텍스트(context)에서는 랭크(rank)당 활성값이 3TB에 달합니다. 체크포인팅(checkpointing)은 부동소수점 연산량(FLOPs)을 메모리와 맞바꿉니다. 저장하는 대신 다시 계산합니다. 핵심 질문은 어떤 구간(segment)을 버릴 것인가이고, 답은 "전부"가 아닙니다.
유형: Build
언어: Python (with numpy, optional torch)
선수 지식: Phase 10 Lesson 04 (Pre-Training Mini-GPT), Phase 10 Lesson 05 (Scaling & Distributed)
예상 시간: 약 70분
문제
트랜스포머(Transformer)를 학습할 때 각 층(layer)은 역방향(backward) 과정에서 미분되는 모든 연산(operation)의 입력을 저장합니다. 어텐션 입력(attention input), Q/K/V 투영(projection), 소프트맥스 출력(softmax output), 피드포워드 망(Feed-Forward Network; FFN)의 입력, 정규화(norm) 출력, 잔차 스트림(residual stream)이 모두 포함됩니다. 은닉 차원(hidden size) d, 시퀀스 길이(sequence length) L, 배치(batch) B인 층에서 이는 층당 대략 12 * B * L * d 부동소수점 값(floats) 수준입니다.
d=8192, L=8192, B=1이라면 BF16 기준으로 층당 800MB입니다. 64층 모델은 활성값만 51GB이고, 이는 마이크로배치(microbatch) 크기를 곱하기 전이며, 어텐션 소프트맥스 중간값(L^2 per head)을 더하기 전이고, 텐서 병렬(tensor parallel) 부분 사본(partial copy)을 고려하기 전입니다.
청구서는 양쪽에서 날아옵니다. BF16 가중치(weight)와 옵티마이저 상태(optimizer state)는 80GB 안에 들어갈 수 있지만 활성값이 그 한계를 넘깁니다. 그래디언트 체크포인팅, 다른 이름으로 활성값 재계산은 표준 해법입니다. 대부분의 활성값을 버리고, 역방향 단계에서 순방향(forward)을 다시 실행해 되찾는 방식입니다. 비용은 추가 FLOPs이고, 이득은 체크포인트 구간 수와 전체 층 수의 비율만큼 메모리가 줄어드는 것입니다.
순진하게 적용하면 체크포인팅은 스텝(step)마다 순방향 FLOPs를 약 33% 더 요구합니다. 잘 적용하면, Korthikanti et al.의 "smart selection" 처럼 선택적 체크포인팅(selective checkpointing)을 사용하면 5배의 메모리 절감을 5% 미만의 FLOP 오버헤드(overhead)로 달성할 수 있습니다. FP8 행렬 곱셈(matmul), FSDP 오프로드(offload), 전문가 병렬 MoE(expert-parallel Mixture-of-Experts; MoE)가 함께 있는 환경에서는 이 차이가 실제로 중요합니다. 메모리도, 낭비되는 연산도 감당할 수 없기 때문입니다.
개념
역방향이 실제로 필요로 하는 것
output = layer(input)이 주어졌을 때, 역방향은 grad_input과 grad_params를 원합니다. 이를 계산하려면 다음이 필요합니다.
input. 선형 층(linear layer)에서는 grad_params = input.T @ grad_output을 계산하기 위해 필요합니다.
- 활성값 미분 중간값(activation derivative intermediate). ReLU/GELU/소프트맥스의 도함수(derivative)는 활성값에 의존합니다.
순방향 단계는 자동 미분 그래프(autograd graph)에 이를 자동으로 저장합니다. 모든 tensor.retain_grad()와 입력이 필요한 모든 연산은 해당 텐서(tensor)에 대한 참조(reference)를 유지합니다.
순진한 전체 체크포인팅(naive full checkpointing)
신경망(network)을 N개의 구간으로 나눕니다. 순방향 중에는 각 구간의 입력만 저장합니다. 역방향이 중간값을 필요로 하면 해당 구간의 순방향을 다시 실행해 구체화(materialize)한 뒤 미분합니다.
예를 들어 32층 트랜스포머를 1개 층씩 32개 구간으로 나눠 봅시다.
- 메모리: 32개의 층-입력만 저장합니다. 이는 작은 양입니다. 반면 모든 활성값을 저장하면
32 * 층당 활성값 부피(activation volume per layer)가 되어 매우 큽니다.
- 추가 연산: 구간마다 추가 순방향 1번씩, 전체 순방향 FLOPs가 약 33% 증가합니다. 역방향은 순방향의 2배이므로 전체 스텝은 기존
1 + 2 = 3 단위에서 1 + 1 + 2 = 4 단위가 됩니다.
이것이 Chen et al. 2016의 원래 처방입니다. 메모리와 연산을 균형 있게 맞추기 위해 sqrt(L) 층마다 체크포인트 하나를 둡니다. L=64라면 체크포인트 8개입니다.
선택적 체크포인팅 (Korthikanti 2022)
모든 활성값이 같은 비용을 가지지는 않습니다. 어텐션 소프트맥스 출력은 B*L*L*heads이고 시퀀스 길이에 대해 *제곱(quadratic)*으로 증가합니다. FFN의 은닉 활성값(hidden activation)은 B*L*4d이고 선형(linear)적으로 증가합니다. 긴 시퀀스에서는 소프트맥스가 지배적입니다.
선택적 체크포인팅은 저장하기 저렴한 활성값(선형 투영, 잔차)은 보관하고, 비싼 활성값(어텐션)만 다시 계산합니다. 최소한의 FLOPs로 재계산하면서 O(L^2) 메모리를 아낄 수 있습니다.
Megatron-Core는 이를 선택적 활성값 재계산("selective" activation recomputation) 모드로 구현합니다. 2024년 이후 최전선 학습(frontier training) 실행에서 표준적으로 사용됩니다.
오프로드(Offload)
재계산의 대안은 활성값을 순방향과 역방향 사이에 CPU RAM으로 보내는 것입니다. PCIe 대역폭(bandwidth)이 필요하며, 유휴 대역폭(idle bandwidth)이 재구체화(rematerialization) 비용을 넘을 때 유리합니다. 혼합 전략도 흔합니다. 일부 층은 체크포인트하고 일부는 오프로드합니다.
FSDP2는 오프로드를 1등급(first-class) 선택지로 제공합니다. GPU가 메모리에 병목(bottlenecked)되어 있지만 CPU-GPU 전송에는 여유(headroom)가 있을 때 오프로드가 빛을 발합니다.
재계산 비용 모델(recompute cost model)
L개 층 중 k 층마다 순진한 체크포인팅을 적용할 때, 스텝당 FLOPs는 다음과 같습니다.
flops_fwd_normal = L * f_layer
flops_bwd_normal = 2 * L * f_layer
flops_total_normal = 3 * L * f_layer
flops_fwd_ckpt = L * f_layer
flops_recompute = L * f_layer # 구간 안의 층마다 추가 순방향 1회
flops_bwd_ckpt = 2 * L * f_layer
flops_total_ckpt = 4 * L * f_layer
overhead = 4 / 3 - 1 = 0.33 = 33%
선택적 체크포인팅에서는 전체 층이 아니라 어텐션 커널(attention kernel)만 다시 계산합니다.
flops_recompute_selective = L * f_attention ~= L * f_layer * 0.15
overhead_selective = (3 + 0.15) / 3 - 1 = 0.05 = 5%
메모리 절감 모델(memory savings model)
층당 활성값 부피를 A라고 합시다. L개 층의 전체 활성값 메모리는 L * A입니다.
전체 체크포인트(구간 크기 1)는 L * input_volume만 저장합니다. 표준 트랜스포머에서는 입력 부피가 대략 1/10 A이므로 큰 폭으로 줄어듭니다(약 9 * L * A * 1/10을 절감).
k층마다 체크포인트하면 L/k * A와 활성 구간(active segment) 안의 k-1개 층 분량을 저장합니다.
k = sqrt(L)에서는 메모리와 재계산 비용이 둘 다 sqrt(L)에 비례합니다. 비용이 균일한 층에 대한 최적의 절충(optimal tradeoff)입니다.
체크포인트하지 말아야 할 때
- 파이프라인 단계(pipeline stage) 안에서 이미 실행 중인(in-flight) 최내부 층(innermost layer). 어차피 끝까지 실행해야 합니다.
- 단계의 연산을 지배하는 첫 번째와 마지막 층. 트랜스포머에서는 드뭅니다.
- 이미 FlashAttention을 쓰는 어텐션 커널. FlashAttention은 소프트맥스를 빠르게 재계산하므로, 추가적인 층 단위 체크포인팅이 주는 이득이 작습니다.
구현 패턴(implementation patterns)
-
함수 래퍼(function wrapper): 구간을 torch.utils.checkpoint.checkpoint(fn, input)으로 감쌉니다. PyTorch는 input만 저장하고 역방향에서 나머지를 재계산합니다.
-
데코레이터 기반(decorator-based): 층에 체크포인트 가능(checkpointable) 라벨을 붙입니다. 트레이너(trainer)가 설정 시점(config time)에 어떤 구간을 감쌀지 결정합니다.
-
수동 명시적 재계산(manual explicit recompute): 역방향 단계를 직접 작성하고, 저장된 입력으로 순방향을 복제하는 사용자 정의 recompute_forward를 호출합니다.
세 방식 모두 기능적으로 같은 결과를 줍니다. 래퍼 방식이 표준적인 관용구(idiom)입니다.
텐서 병렬(TP) / 파이프라인 병렬(PP) / FP8과의 상호작용
- 텐서 병렬(Tensor Parallel; TP): 체크포인트 입력은 재계산 중에 수집(gather)되거나 재분산(rescatter)되어야 합니다. 통신 비용(communication cost)을 함께 처리해야 합니다.
- 파이프라인 병렬(Pipeline Parallel; PP): 전형적인 패턴은 각 파이프라인 단계의 순방향을 체크포인트해서, 역순(reverse-order) 마이크로배치가 활성값 메모리를 재사용하도록 하는 것입니다.
- FP8 재계산: 재계산 중에 갱신되는 amax 이력(amax history)이 원래 순방향과 일치해야 합니다. 그렇지 않으면 FP8 스케일(scale)이 표류(drift)합니다. 대부분의 프레임워크는 스케일 값을 스냅샷(snapshot)으로 고정합니다.
직접 만들기
Step 1: 구간을 가진 장난감 모델(toy model)
import numpy as np
def linear_forward(x, w, b):
return x @ w + b
def relu(x):
return np.maximum(x, 0)
def layer_forward(x, w1, b1, w2, b2):
h = relu(linear_forward(x, w1, b1))
return linear_forward(h, w2, b2)
def model_forward(x, params):
activations = [x]
h = x
for w1, b1, w2, b2 in params:
h = layer_forward(h, w1, b1, w2, b2)
activations.append(h)
return h, activations
Step 2: 모든 활성값을 필요로 하는 순진한 역방향(naive backward)
def model_backward(grad_output, activations, params):
grads = [None] * len(params)
g = grad_output
for i in range(len(params) - 1, -1, -1):
w1, b1, w2, b2 = params[i]
x_in = activations[i]
h_pre = linear_forward(x_in, w1, b1)
h = relu(h_pre)
gh = g @ w2.T
gw2 = h.T @ g
gb2 = g.sum(axis=0)
g_pre = gh * (h_pre > 0)
gx = g_pre @ w1.T
gw1 = x_in.T @ g_pre
gb1 = g_pre.sum(axis=0)
grads[i] = (gw1, gb1, gw2, gb2)
g = gx
return g, grads
Step 3: k 층마다 체크포인트하는 메모리 구조(checkpoint-every-k memory)
def model_forward_checkpointed(x, params, k=4):
saved_inputs = [x]
h = x
for i, (w1, b1, w2, b2) in enumerate(params):
h = layer_forward(h, w1, b1, w2, b2)
if (i + 1) % k == 0:
saved_inputs.append(h)
return h, saved_inputs
def model_backward_checkpointed(grad_output, saved_inputs, params, k=4):
grads = [None] * len(params)
g = grad_output
segments = [(j * k, min((j + 1) * k, len(params))) for j in range(len(saved_inputs))]
for seg_idx in range(len(saved_inputs) - 1, -1, -1):
start, end = segments[seg_idx]
if start >= end:
continue
x_in = saved_inputs[seg_idx]
_, seg_acts = model_forward(x_in, params[start:end])
g, seg_grads = model_backward(g, seg_acts, params[start:end])
for j, gr in enumerate(seg_grads):
grads[start + j] = gr
return g, grads
Step 4: 비용 모델(cost model)
def checkpoint_cost(n_layers, segment_size, flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
def selective_checkpoint_cost(n_layers, attention_fraction=0.15,
flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * attention_fraction * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
Step 5: 메모리 추정기(memory estimator)
def activation_memory_mb(n_layers, hidden=8192, seq=8192,
batch=1, bytes_per_value=2):
per_layer = 12 * batch * seq * hidden * bytes_per_value
return n_layers * per_layer / 1e6
def memory_after_checkpoint(n_layers, segment_size, hidden=8192,
seq=8192, batch=1, bytes_per_value=2):
n_seg = max(1, n_layers // segment_size)
saved = (n_seg + segment_size) * 1 * batch * seq * hidden * bytes_per_value
return saved / 1e6
Step 6: 최적 구간 크기(optimal segment size)
def optimal_segment(n_layers):
return int(round(np.sqrt(n_layers)))
Step 7: 선택적 체크포인트 결정(selective checkpoint decision)
def should_recompute(layer_type, activation_bytes, recompute_flops_ratio):
if layer_type == "attention" and activation_bytes > 100 * 1e6:
return True
if layer_type == "ffn" and activation_bytes > 500 * 1e6:
return recompute_flops_ratio < 0.1
return False
사용해보기
torch.utils.checkpoint: from torch.utils.checkpoint import checkpoint는 PyTorch의 표준(canonical) 래퍼입니다. 함수를 감싸서 입력만 저장하고 역방향에서 나머지를 재계산합니다.
- Megatron-Core 활성값 재계산(activation recomputation):
selective, full, block 모드를 지원합니다. 2024년 이후 최전선 학습의 표준입니다.
- FSDP2 오프로드(offload): FSDP2의
module.to_empty(device="cpu")와 offload_policy는 활성값을 재계산하는 대신 CPU로 분산/오프로드합니다.
- DeepSpeed ZeRO-Offload: 옵티마이저 상태와 활성값을 위한 CPU 오프로드입니다. 체크포인팅을 보완합니다.
산출물 만들기
이 lesson은 outputs/prompt-activation-recompute-policy.md를 만듭니다. 모델 설정(layers, hidden, seq, batch)과 사용 가능한 GPU 메모리를 입력받아 층별 재계산 정책(none / selective / full / offload)을 제안하는 프롬프트(prompt)입니다.
연습문제
-
정확성(correctness)을 검증합니다. model_forward + model_backward(전체 활성값)과 model_forward_checkpointed + model_backward_checkpointed(구간 기반)을 실행합니다. 파라미터 그래디언트는 기계 정밀도(machine precision)까지 동일해야 합니다.
-
구간 크기 k를 1부터 L까지 훑어봅니다(sweep). FLOP 오버헤드와 메모리를 그립니다(plot). 곡선의 무릎(knee) 지점을 찾습니다.
-
선택적 체크포인팅을 구현합니다. 어텐션 모듈 입력은 저장하되 그 중간값은 저장하지 않습니다. 시퀀스 길이 8192의 32층 모델에서 전체 층 체크포인팅 대비 FLOP 오버헤드를 측정합니다.
-
오프로드를 추가합니다. 구간 입력을 모의 "CPU 버퍼(buffer)", 즉 별도 list에 저장합니다. "PCIe 대역폭"을 바이트/시간(bytes/time)으로 측정하고 오프로드와 재계산 사이의 손익분기점(breakeven point)을 찾습니다.
-
실제 PyTorch 트랜스포머를 torch.utils.checkpoint 사용/미사용으로 벤치마크(benchmark)합니다. 메모리(torch.cuda.max_memory_allocated)와 스텝 시간을 측정합니다.
핵심 용어
| 용어 | 흔한 설명 | 실제 의미 |
|---|
| 그래디언트 체크포인팅(Gradient checkpointing) | "순방향을 다시 해서 메모리를 절약" | 구간 입력만 저장하고 역방향 중 중간값을 재계산해 그래디언트 계산에 필요한 텐서(gradient-support tensor)를 복원한다. |
| 활성값 재계산(Activation recomputation) | "체크포인팅과 같은 것" | 같은 기법의 고성능 컴퓨팅(High-Performance Computing; HPC) 쪽 이름이다. |
구간 크기(Segment size, k) | "체크포인트당 층 수" | 중간값을 버렸다가 함께 재구체화하는 층의 수이다. |
| 선택적 체크포인팅(Selective checkpointing) | "Korthikanti의 트릭" | 어텐션 소프트맥스처럼 저장 비용이 큰 활성값만 재계산하고, 저렴한 것은 보관한다. |
| 전체 체크포인팅(Full checkpointing) | "순진한 버전" | 모든 구간에서 모든 층의 중간값을 재계산한다. |
| 블록 체크포인팅(Block checkpointing) | "거친 단위(coarse-grained)" | 트랜스포머 블록 전체를 체크포인트한다. 가장 큰 단위(granularity)이다. |
| FLOP 오버헤드 | "연산 세금(compute tax)" | 스텝당 추가 FLOPs = (재계산 FLOPs) / (순방향 + 역방향 FLOPs). 순진한 방식은 33%, 선택적 방식은 5%이다. |
| 활성값 오프로드(Activation offload) | "CPU로 보내기" | 순방향과 역방향 사이에 활성값을 CPU RAM으로 이동한다. 재계산의 대안이다. |
| sqrt-L 규칙(sqrt-L rule) | "고전적 최적(classical optimum)" | 비용이 균일한 층에서, 최적 체크포인트 간격은 sqrt(L) 층이다. |
| 어텐션 소프트맥스 부피(Attention-softmax volume) | "O(L^2) 문제" | L^2 * heads * batch floats. 긴 컨텍스트에서 활성값 메모리를 지배한다. |
더 읽을거리