개념
JAX의 철학
JAX는 함수형 프레임워크(Functional Framework)입니다. 클래스(class)도, 변경 가능한 상태(mutable state)도, .backward() 메서드(method)도 중심이 아닙니다.
| PyTorch | JAX |
|---|
상태를 가진 nn.Module 클래스 | 순수 함수: f(params, x) -> y |
loss.backward() | jax.grad(loss_fn)(params, x, y) |
| 즉시 실행(Eager Execution) | XLA 기반 JIT 컴파일 |
for x in batch: 수동 반복문 | jax.vmap(f) 자동 벡터화 |
DataParallel / FSDP | jax.pmap(f) 자동 병렬화(parallelism) |
변경 가능한 model.parameters() | 불변(immutable) 배열의 파이트리(pytree) |
이것은 취향 문제가 아니라 컴파일러(compiler) 제약입니다. JIT 컴파일은 순수 함수가 필요합니다. 같은 입력은 항상 같은 출력을 내야 하고, 부작용(side effect)이 없어야 합니다. 이 제약이 100배 속도 향상의 기반이 됩니다.
jax.numpy: 익숙한 표면
JAX는 가속기(accelerator) 위에서 동작하는 NumPy API를 제공합니다.
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)
함수 이름, 브로드캐스팅(broadcasting) 규칙, 슬라이싱(slicing) 의미는 NumPy와 비슷합니다. 하지만 배열(array)은 GPU/TPU 위에 있을 수 있고, 모든 연산은 컴파일러가 추적할 수 있습니다.
중요한 차이는 JAX 배열이 불변(immutable)이라는 점입니다. a[0] = 5처럼 수정하지 않습니다. 대신 a = a.at[0].set(5)를 씁니다. 처음에는 어색하지만, 이런 불변성(immutability)이 grad, jit, vmap 같은 변환(transformation)을 조합 가능하게 만듭니다.
jax.grad: 함수형 자동 미분
PyTorch는 기울기를 텐서의 .grad에 붙입니다. JAX는 기울기를 함수에 붙입니다.
import jax
def f(x):
return x ** 2
df = jax.grad(f)
df(3.0)
jax.grad는 함수를 받아 기울기를 계산하는 새 함수를 반환합니다. .backward() 호출도, 텐서에 저장되는 계산 그래프(computation graph)도 없습니다. 반환된 기울기 함수도 다시 호출하거나, 조합하거나, JIT 컴파일할 수 있는 또 하나의 함수일 뿐입니다.
d2f = jax.grad(jax.grad(f))
d2f(3.0)
2차 도함수(derivative), 3차 도함수, 야코비안(Jacobian), 헤시안(Hessian)을 모두 grad 조합으로 표현할 수 있습니다. PyTorch에서도 가능하지만 JAX에서는 이것이 핵심 토대입니다.
제약도 있습니다. grad는 순수 함수에 잘 맞습니다. 추적(tracing) 중 실행되는 print, 외부 상태(state) 변경, 명시적 키(key) 없는 난수 생성(random generation)은 피해야 합니다.
jit: XLA로 컴파일하기
@jax.jit
def train_step(params, x, y):
loss = loss_fn(params, x, y)
return loss
fast_step = jax.jit(train_step)
첫 호출에서 JAX는 함수를 추적합니다. 실제 계산을 수행하기보다 어떤 연산이 일어나는지 기록합니다. 그 추적 결과를 XLA(Accelerated Linear Algebra)에 넘기면 XLA가 연산 융합(operation fusion), 불필요한 메모리 복사(memory copy) 제거, 최적화된 기계어 생성을 수행합니다.
이후 호출에서는 Python을 건너뛰고 컴파일된 코드가 가속기에서 C++에 가까운 속도로 실행됩니다.
JIT가 도움이 되는 경우는 반복되는 학습 단계, 같은 모델의 추론(inference), 비슷한 모양(shape)의 입력으로 여러 번 호출되는 함수입니다. 반대로 값에 의존하는 Python 제어 흐름(control flow), 한 번만 실행하는 계산, 디버깅에는 불리할 수 있습니다. JIT 안에서는 if/else 대신 jax.lax.cond, 반복문 대신 jax.lax.scan 같은 기본 연산(primitive)을 사용해야 할 때가 있습니다.
vmap: 자동 벡터화
하나의 예시를 처리하는 함수를 작성합니다.
def predict(params, x):
return jnp.dot(params['w'], x) + params['b']
vmap은 이 함수를 배치를 처리하는 함수로 끌어올립니다.
batch_predict = jax.vmap(predict, in_axes=(None, 0))
in_axes=(None, 0)은 params는 배치 차원으로 묶지 않고 공유하며, x의 0번 축(axis)을 배치 차원(batch dimension)으로 사용한다는 뜻입니다. 수동 반복문도, 모양 재구성(reshape)도, 배치 차원을 함수 곳곳에 전달하는 작업도 필요 없습니다.
vmap은 단순한 문법 설탕(syntactic sugar)이 아닙니다. 융합된 벡터화 코드(fused vectorized code)를 만들어 Python 반복문보다 10~100배 빠를 수 있습니다. jit, grad와도 조합됩니다.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))
한 줄로 예시별 기울기(per-example gradient)를 계산합니다. PyTorch에서는 편법(hack) 없이는 구현하기 어려운 패턴입니다.
pmap: 여러 장치로 데이터 병렬화
parallel_step = jax.pmap(train_step, axis_name='devices')
pmap은 함수를 사용 가능한 모든 GPU/TPU에 복제하고 배치를 나눕니다. 함수 안에서는 jax.lax.pmean, jax.lax.psum으로 장치(device) 간 기울기를 동기화합니다.
Google은 pmap과 그 후속 도구인 shard_map 계열을 사용해 Gemini를 수천 개의 TPU 칩(chip)에 걸쳐 학습합니다. 프로그래밍 모델(programming model)은 단일 장치 버전을 작성하고 pmap으로 감싸는 방식입니다.
파이트리(Pytree)
JAX는 리스트, 튜플, 딕셔너리(dict), 배열이 중첩된 구조인 파이트리를 기본 자료구조로 다룹니다. 모델 매개변수도 파이트리입니다.
params = {
'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
'layer3': {'w': jnp.zeros((128, 10)), 'b': jnp.zeros(10)},
}
grad, jit, vmap 같은 JAX 변환은 파이트리를 순회하는 방법을 알고 있습니다. 옵티마이저 갱신(update)도 트리 전체에 적용됩니다.
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
.parameters() 메서드나 매개변수 등록(parameter registration)은 없습니다. 트리 구조 자체가 모델입니다.
함수형 대 객체 지향(Functional vs Object-Oriented)
PyTorch는 상태를 객체(object) 안에 저장합니다.
class Model(nn.Module):
def __init__(self):
self.linear = nn.Linear(784, 10)
def forward(self, x):
return self.linear(x)
JAX는 명시적 상태를 받는 순수 함수를 사용합니다.
def predict(params, x):
return jnp.dot(x, params['w']) + params['b']
params는 입력으로 들어오고, 아무것도 내부에 저장되거나 변경되지 않습니다. 그래서 함수는 테스트 가능(testable)하고 조합 가능(composable)하며 컴파일 가능(compilable)합니다. 대신 매개변수를 직접 관리하거나 Flax, Equinox 같은 라이브러리를 사용해야 합니다.
JAX 생태계
JAX는 기본 연산을 제공합니다. 라이브러리는 사용성을 보완합니다.
| 라이브러리 | 역할 | 스타일 |
|---|
| Flax(Google) | 신경망 계층(layer) | 명시적 상태를 가진 nn.Module |
| Equinox(Patrick Kidger) | 신경망 계층 | 파이트리 기반, Pythonic |
| Optax(DeepMind) | 옵티마이저와 학습률(LR) 스케줄 | 조합 가능한 기울기 변환(gradient transform) |
| Orbax(Google) | 체크포인팅(Checkpointing) | 파이트리 저장/복원 |
| CLU(Google) | 지표(metric)와 로깅 | 학습 루프 유틸리티 |
Optax는 표준 옵티마이저 라이브러리입니다. Adam, SGD, 클리핑(clipping) 같은 기울기 변환을 매개변수 갱신과 분리하므로 조합이 쉽습니다.
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=1e-3),
)
JAX와 PyTorch를 선택하는 기준
| 기준 | JAX | PyTorch |
|---|
| TPU 지원 | 일급(first-class) 지원 | 커뮤니티 유지(torch_xla) |
| GPU 지원 | 좋음(XLA를 통한 CUDA) | 최고 수준(native CUDA) |
| 디버깅 | 어려움(추적 + 컴파일) | 쉬움(즉시 실행, 줄 단위) |
| 생태계 | 연구 중심(Flax, Equinox) | 매우 큼(HuggingFace, torchvision 등) |
| 채용 | 좁은 시장(Google/DeepMind/Anthropic) | 주류(mainstream) |
| 대규모 학습 | 강함(XLA, pmap, mesh) | 강함(FSDP, DeepSpeed) |
| 프로토타이핑 속도 | 느릴 수 있음(함수형 오버헤드) | 빠름(변경하며 바로 진행) |
| 운영 환경 추론 | TensorFlow Serving, Vertex AI | TorchServe, Triton, ONNX |
| 사용하는 조직 | DeepMind(Gemini), Anthropic(Claude) | Meta(Llama), OpenAI(GPT), Stability AI |
솔직한 기준은 이렇습니다. 특별한 이유가 없다면 PyTorch를 사용합니다. JAX를 선택할 이유는 TPU 접근성, 예시별 기울기 계산, 대규모 다중 장치 학습, Google/DeepMind/Anthropic 스타일의 작업 환경 등입니다.
JAX의 난수(Random Numbers)
JAX에는 전역 난수 상태(global random state)가 없습니다. 모든 난수 연산은 명시적 PRNG 키가 필요합니다.
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))
처음에는 번거롭지만, 장치와 컴파일을 넘나들 때 재현성(Reproducibility)을 보장하는 데 도움이 됩니다. 여러 GPU 환경에서 torch.manual_seed만으로는 보장하기 어려운 성질입니다.
만들어 보기
Step 1: 설정과 데이터
JAX와 Optax로 MNIST에 3계층 MLP를 학습합니다. 입력 784개, 은닉 계층(hidden layer)은 각각 256개와 128개의 뉴런(neuron), 출력 클래스(output class)는 10개입니다.
import jax
import jax.numpy as jnp
from jax import random
import optax
def get_mnist_data():
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
X = mnist.data.astype('float32') / 255.0
y = mnist.target.astype('int')
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
return X_train, y_train, X_test, y_test
Step 2: 매개변수(Parameter) 초기화
클래스 없이 파이트리를 반환하는 함수만 둡니다.
def init_params(key):
k1, k2, k3 = random.split(key, 3)
scale1 = jnp.sqrt(2.0 / 784)
scale2 = jnp.sqrt(2.0 / 256)
scale3 = jnp.sqrt(2.0 / 128)
params = {
'layer1': {'w': scale1 * random.normal(k1, (784, 256)), 'b': jnp.zeros(256)},
'layer2': {'w': scale2 * random.normal(k2, (256, 128)), 'b': jnp.zeros(128)},
'layer3': {'w': scale3 * random.normal(k3, (128, 10)), 'b': jnp.zeros(10)},
}
return params
He 초기화(He initialization)를 수동으로 적용합니다. 하나의 시드(seed)에서 세 개의 PRNG 키를 나눕니다. 모든 가중치(weight)는 중첩된 딕셔너리 안의 불변 배열입니다.
Step 3: 순전파(Forward Pass)
def forward(params, x):
x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
return x
def loss_fn(params, x, y):
logits = forward(params, x)
one_hot = jax.nn.one_hot(y, 10)
return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))
순수 함수입니다. params가 들어오고 예측(prediction)이 나갑니다. self도, 저장된 상태도 없습니다. loss_fn은 교차 엔트로피(cross-entropy)를 직접 계산합니다.
Step 4: JIT 컴파일된 학습 단계(Training Step)
@jax.jit
def train_step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
@jax.jit
def accuracy(params, x, y):
logits = forward(params, x)
preds = jnp.argmax(logits, axis=-1)
return jnp.mean(preds == y)
jax.value_and_grad는 한 번의 패스(pass)에서 손실 값(loss value)과 기울기를 모두 반환합니다. @jax.jit 데코레이터(decorator)는 두 함수를 XLA로 컴파일합니다. 첫 호출 이후에는 Python을 거치지 않고 각 학습 단계가 실행됩니다.
Step 5: 학습 루프(Training Loop)
optimizer = optax.adam(learning_rate=1e-3)
X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)
key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)
batch_size = 128
n_epochs = 10
for epoch in range(n_epochs):
key, subkey = random.split(key)
perm = random.permutation(subkey, len(X_train))
X_shuffled = X_train[perm]
y_shuffled = y_train[perm]
epoch_loss = 0.0
n_batches = len(X_train) // batch_size
for i in range(n_batches):
start = i * batch_size
xb = X_shuffled[start:start + batch_size]
yb = y_shuffled[start:start + batch_size]
params, opt_state, loss = train_step(params, opt_state, xb, yb)
epoch_loss += loss
train_acc = accuracy(params, X_train[:5000], y_train[:5000])
test_acc = accuracy(params, X_test, y_test)
print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
10 에포크(epoch) 후 테스트 정확도(test accuracy)는 약 97%입니다. 첫 에포크는 JIT 컴파일 때문에 느립니다. 2~10 에포크는 빠릅니다.
여기에는 .zero_grad(), .backward(), .step()이 없습니다. 갱신 전체가 조합된(composed) 하나의 함수 호출입니다. 기울기 계산, Adam 변환, 매개변수 적용이 모두 train_step 안에서 일어납니다.
사용하기
Flax: Google 표준
Flax는 가장 널리 쓰이는 JAX 신경망 라이브러리입니다. nn.Module을 다시 제공하지만 상태 관리는 명시적입니다.
import flax.linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)
구조는 PyTorch와 비슷하지만 params가 모델과 분리되어 있습니다. model.init()이 매개변수를 만들고, model.apply(params, x)가 순전파를 실행합니다. 모델 객체 자체에는 상태가 없습니다.
Equinox: 파이썬다운(Pythonic) 대안
Equinox는 모델을 파이트리로 표현합니다.
import equinox as eqx
model = eqx.nn.MLP(
in_size=784, out_size=10, width_size=256, depth=2,
activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)
모델 자체가 파이트리이므로 .apply()가 필요 없습니다. 매개변수는 모델의 잎(leaf)입니다. JAX가 생각하는 방식에 더 가까운 형태입니다.
Optax: 조합 가능한 옵티마이저
Optax는 기울기 변환과 갱신을 분리합니다.
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0, peak_value=1e-3,
warmup_steps=1000, decay_steps=50000
)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(learning_rate=schedule, weight_decay=0.01),
)
기울기 클리핑(Gradient clipping), 학습률 워밍업(learning rate warmup), 가중치 감쇠(weight decay)를 체인(chain)으로 조합합니다. 각 변환은 기울기를 받고 수정한 뒤 다음 변환으로 넘깁니다. 하나의 거대한 옵티마이저 클래스에 모든 동작을 몰아넣지 않습니다.
산출물 만들기
이 lesson의 최종 산출물은 다음 두 가지입니다.
outputs/prompt-jax-optimizer.md: JAX 옵티마이저 구성(optimizer configuration)을 선택하는 prompt
outputs/skill-jax-patterns.md: JAX의 함수형 패턴(functional pattern)을 정리한 skill
설치와 운영 참고
pip install jax jaxlib optax flax
pip install jax[cuda12]
TPU(Google Cloud)에서는 다음처럼 설치합니다.
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
성능 관련 주의점은 다음과 같습니다.
- 첫 JIT 호출은 컴파일 때문에 느립니다. 벤치마크(benchmark) 전 워밍업(warm up)을 수행합니다.
- JIT 안에서 JAX 배열을 Python 반복문으로 순회하지 않습니다.
jax.lax.scan, jax.lax.fori_loop를 사용합니다.
- JIT 내부에서 출력할 때는 일반
print 대신 jax.debug.print()를 사용합니다.
jax.profiler나 TensorBoard로 프로파일링(profile)합니다. XLA 컴파일은 병목을 숨길 수 있습니다.
- JAX는 기본적으로 GPU 메모리의 75%를 사전 할당(pre-allocate)합니다. 필요하면
XLA_PYTHON_CLIENT_PREALLOCATE=false를 설정합니다.
체크포인팅은 Orbax로 파이트리를 저장하고 복원합니다.
import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')
연습문제
- (쉬움) MLP에 드롭아웃(dropout)을 추가합니다. JAX의 드롭아웃은 PRNG 키가 필요하므로 순전파에 키를 전달하고 드롭아웃 계층마다 분할(split)합니다. 적용 전후 테스트 정확도를 비교합니다.
- (중간)
jax.vmap으로 MNIST 32개 배치의 예시별 기울기를 계산합니다. 각 예시의 기울기 노름(gradient norm)을 구합니다. 어떤 예시의 기울기가 가장 크고, 왜 그런지 분석합니다.
- (중간) 수동으로 작성한 순전파 함수를 임의의 계층 수에 대응하는
mlp_forward(params, x)로 바꿉니다. jax.tree.leaves로 깊이(depth)를 자동 판단합니다.
- (중간)
@jax.jit 적용 전후의 학습 단계를 벤치마크합니다. 각각 100단계를 측정합니다. 사용 중인 하드웨어에서 속도 향상(speedup)과 첫 호출의 컴파일 부가 비용이 어느 정도인지 확인합니다.
- (어려움)
optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))로 기울기 클리핑을 구현합니다. 클리핑 전후로 학습하고 기울기 노름을 시각화(plot)합니다.
핵심 용어
| 용어 | 흔한 설명 | 실제 의미 |
|---|
| XLA | JAX를 빠르게 만드는 것 | 계산 그래프에서 연산을 융합하고 GPU/TPU 커널을 생성하는 컴파일러 |
| JIT | 적시 컴파일(Just-in-time compilation) | 첫 호출에서 함수를 추적하고 XLA로 컴파일한 뒤 이후 호출에서 컴파일된 버전을 실행하는 방식 |
| 순수 함수(Pure Function) | 부작용이 없음 | 출력이 오직 입력에만 의존하며 전역 상태, 변경, 명시적 키 없는 난수성이 없는 함수 |
vmap | 자동 배치(auto-batching) | 하나의 예시를 처리하는 함수를 배치를 처리하는 함수로 변환하는 변환 |
pmap | 자동 병렬화(auto-parallelism) | 함수를 여러 장치에 복제하고 입력 배치를 분할하는 변환 |
| 파이트리(Pytree) | 배열의 중첩 딕셔너리 | JAX가 순회하고 변환할 수 있는 리스트, 튜플, 딕셔너리, 배열의 중첩 구조 |
| 추적(tracing) | 계산을 기록하는 것 | 실제 값을 계산하기보다 추상값(abstract value)으로 함수를 실행해 계산 그래프를 만드는 과정 |
| 함수형 자동 미분(functional autodiff) | 함수의 기울기 | 텐서에 기울기 저장소를 붙이는 대신 함수를 변환해 도함수를 계산하는 방식 |
| Optax | JAX 옵티마이저 라이브러리 | Adam, SGD, 클리핑, 스케줄링을 체인으로 조합하는 기울기 변환 라이브러리 |
| Flax | JAX의 nn.Module | 명시적 상태를 유지하면서 계층 추상화(layer abstraction)를 제공하는 Google의 JAX 신경망 라이브러리 |
더 읽을거리