JAX 입문(Introduction to JAX)

PyTorch는 텐서(Tensor)를 변경합니다. TensorFlow는 그래프(Graph)를 만듭니다. JAX는 순수 함수(Pure Function)를 컴파일합니다. 마지막 차이가 딥러닝(Deep Learning)을 바라보는 방식을 바꿉니다.

유형: Build 언어: Python 선수 학습: Phase 03 Lessons 01-10, 기본 NumPy 소요 시간: 약 90분

학습 목표

  • JAX의 함수형 API(Functional API; jax.numpy, jax.grad, jax.jit, jax.vmap)를 사용해 순수 함수 기반 신경망 코드를 작성합니다.
  • PyTorch의 즉시 실행 변경(Eager Mutation) 모델과 JAX의 함수형 컴파일(Functional Compilation) 모델의 핵심 차이를 설명합니다.
  • JIT 컴파일(Just-in-time Compilation)과 vmap 벡터화(Vectorization)를 적용해 단순한 Python 반복문(loop)보다 빠른 학습 루프를 만듭니다.
  • JAX로 간단한 신경망(network)을 학습하고, PyTorch의 객체 지향 접근과 달리 상태(state)를 명시적으로 다루는 방식을 비교합니다.

문제

PyTorch로 신경망을 만드는 방법은 이미 알고 있습니다. nn.Module을 정의하고, .backward()를 호출하고, 옵티마이저(optimizer)를 한 단계(step) 진행합니다. 잘 동작하고 수많은 사람이 사용합니다.

하지만 PyTorch에는 설계상 제약이 있습니다. Python에서 연산(operation)을 하나씩 즉시(eager) 추적(trace)합니다. tensor + tensor마다 별도의 커널 실행(kernel launch)이 발생할 수 있고, 매 학습 단계(training step)마다 같은 Python 코드가 다시 해석됩니다. 작은 모델에서는 괜찮지만, 2,048개 TPU에 걸쳐 5,400억 매개변수(parameter) 모델을 학습해야 한다면 이 부가 비용(overhead)이 치명적입니다.

Google DeepMind는 Gemini를 JAX로 학습합니다. Anthropic도 Claude 학습에 JAX를 사용했습니다. 이들은 지구상에서 가장 큰 규모의 신경망 학습 실행(neural network training run)에 속합니다. JAX를 선택한 이유는 학습 루프(training loop)를 Python 호출의 나열이 아니라 컴파일 가능한 프로그램으로 보기 때문입니다.

JAX는 세 가지 능력을 가진 NumPy입니다. 자동 미분(Automatic Differentiation), XLA로의 JIT 컴파일(JIT Compilation), 자동 벡터화(Vectorization)입니다. 하나의 예시(example)를 처리하는 함수를 작성하면, JAX는 그 함수를 배치(batch) 처리, 기울기(gradient) 계산, 기계어(machine code) 컴파일, 여러 장치(multi-device) 실행으로 확장할 수 있게 해 줍니다.

사전 테스트

2문제 · 이 강의를 시작하기 전에 얼마나 알고 있는지 확인해보세요

1.PyTorch와 JAX의 근본적인 설계 차이는 무엇입니까?

2.jax.jit은 무엇을 합니까?

0/2 답변 완료

개념

JAX의 철학

JAX는 함수형 프레임워크(Functional Framework)입니다. 클래스(class)도, 변경 가능한 상태(mutable state)도, .backward() 메서드(method)도 중심이 아닙니다.

PyTorchJAX
상태를 가진 nn.Module 클래스순수 함수: f(params, x) -> y
loss.backward()jax.grad(loss_fn)(params, x, y)
즉시 실행(Eager Execution)XLA 기반 JIT 컴파일
for x in batch: 수동 반복문jax.vmap(f) 자동 벡터화
DataParallel / FSDPjax.pmap(f) 자동 병렬화(parallelism)
변경 가능한 model.parameters()불변(immutable) 배열의 파이트리(pytree)

이것은 취향 문제가 아니라 컴파일러(compiler) 제약입니다. JIT 컴파일은 순수 함수가 필요합니다. 같은 입력은 항상 같은 출력을 내야 하고, 부작용(side effect)이 없어야 합니다. 이 제약이 100배 속도 향상의 기반이 됩니다.

jax.numpy: 익숙한 표면

JAX는 가속기(accelerator) 위에서 동작하는 NumPy API를 제공합니다.

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)

함수 이름, 브로드캐스팅(broadcasting) 규칙, 슬라이싱(slicing) 의미는 NumPy와 비슷합니다. 하지만 배열(array)은 GPU/TPU 위에 있을 수 있고, 모든 연산은 컴파일러가 추적할 수 있습니다.

중요한 차이는 JAX 배열이 불변(immutable)이라는 점입니다. a[0] = 5처럼 수정하지 않습니다. 대신 a = a.at[0].set(5)를 씁니다. 처음에는 어색하지만, 이런 불변성(immutability)이 grad, jit, vmap 같은 변환(transformation)을 조합 가능하게 만듭니다.

jax.grad: 함수형 자동 미분

PyTorch는 기울기를 텐서의 .grad에 붙입니다. JAX는 기울기를 함수에 붙입니다.

import jax

def f(x):
    return x ** 2

df = jax.grad(f)
df(3.0)

jax.grad는 함수를 받아 기울기를 계산하는 새 함수를 반환합니다. .backward() 호출도, 텐서에 저장되는 계산 그래프(computation graph)도 없습니다. 반환된 기울기 함수도 다시 호출하거나, 조합하거나, JIT 컴파일할 수 있는 또 하나의 함수일 뿐입니다.

d2f = jax.grad(jax.grad(f))
d2f(3.0)

2차 도함수(derivative), 3차 도함수, 야코비안(Jacobian), 헤시안(Hessian)을 모두 grad 조합으로 표현할 수 있습니다. PyTorch에서도 가능하지만 JAX에서는 이것이 핵심 토대입니다.

제약도 있습니다. grad는 순수 함수에 잘 맞습니다. 추적(tracing) 중 실행되는 print, 외부 상태(state) 변경, 명시적 키(key) 없는 난수 생성(random generation)은 피해야 합니다.

jit: XLA로 컴파일하기

@jax.jit
def train_step(params, x, y):
    loss = loss_fn(params, x, y)
    return loss

fast_step = jax.jit(train_step)

첫 호출에서 JAX는 함수를 추적합니다. 실제 계산을 수행하기보다 어떤 연산이 일어나는지 기록합니다. 그 추적 결과를 XLA(Accelerated Linear Algebra)에 넘기면 XLA가 연산 융합(operation fusion), 불필요한 메모리 복사(memory copy) 제거, 최적화된 기계어 생성을 수행합니다.

이후 호출에서는 Python을 건너뛰고 컴파일된 코드가 가속기에서 C++에 가까운 속도로 실행됩니다.

JIT가 도움이 되는 경우는 반복되는 학습 단계, 같은 모델의 추론(inference), 비슷한 모양(shape)의 입력으로 여러 번 호출되는 함수입니다. 반대로 값에 의존하는 Python 제어 흐름(control flow), 한 번만 실행하는 계산, 디버깅에는 불리할 수 있습니다. JIT 안에서는 if/else 대신 jax.lax.cond, 반복문 대신 jax.lax.scan 같은 기본 연산(primitive)을 사용해야 할 때가 있습니다.

vmap: 자동 벡터화

하나의 예시를 처리하는 함수를 작성합니다.

def predict(params, x):
    return jnp.dot(params['w'], x) + params['b']

vmap은 이 함수를 배치를 처리하는 함수로 끌어올립니다.

batch_predict = jax.vmap(predict, in_axes=(None, 0))

in_axes=(None, 0)params는 배치 차원으로 묶지 않고 공유하며, x의 0번 축(axis)을 배치 차원(batch dimension)으로 사용한다는 뜻입니다. 수동 반복문도, 모양 재구성(reshape)도, 배치 차원을 함수 곳곳에 전달하는 작업도 필요 없습니다.

vmap은 단순한 문법 설탕(syntactic sugar)이 아닙니다. 융합된 벡터화 코드(fused vectorized code)를 만들어 Python 반복문보다 10~100배 빠를 수 있습니다. jit, grad와도 조합됩니다.

per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

한 줄로 예시별 기울기(per-example gradient)를 계산합니다. PyTorch에서는 편법(hack) 없이는 구현하기 어려운 패턴입니다.

pmap: 여러 장치로 데이터 병렬화

parallel_step = jax.pmap(train_step, axis_name='devices')

pmap은 함수를 사용 가능한 모든 GPU/TPU에 복제하고 배치를 나눕니다. 함수 안에서는 jax.lax.pmean, jax.lax.psum으로 장치(device) 간 기울기를 동기화합니다.

Google은 pmap과 그 후속 도구인 shard_map 계열을 사용해 Gemini를 수천 개의 TPU 칩(chip)에 걸쳐 학습합니다. 프로그래밍 모델(programming model)은 단일 장치 버전을 작성하고 pmap으로 감싸는 방식입니다.

파이트리(Pytree)

JAX는 리스트, 튜플, 딕셔너리(dict), 배열이 중첩된 구조인 파이트리를 기본 자료구조로 다룹니다. 모델 매개변수도 파이트리입니다.

params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
    'layer3': {'w': jnp.zeros((128, 10)),  'b': jnp.zeros(10)},
}

grad, jit, vmap 같은 JAX 변환은 파이트리를 순회하는 방법을 알고 있습니다. 옵티마이저 갱신(update)도 트리 전체에 적용됩니다.

params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

.parameters() 메서드나 매개변수 등록(parameter registration)은 없습니다. 트리 구조 자체가 모델입니다.

함수형 대 객체 지향(Functional vs Object-Oriented)

PyTorch는 상태를 객체(object) 안에 저장합니다.

class Model(nn.Module):
    def __init__(self):
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear(x)

JAX는 명시적 상태를 받는 순수 함수를 사용합니다.

def predict(params, x):
    return jnp.dot(x, params['w']) + params['b']

params는 입력으로 들어오고, 아무것도 내부에 저장되거나 변경되지 않습니다. 그래서 함수는 테스트 가능(testable)하고 조합 가능(composable)하며 컴파일 가능(compilable)합니다. 대신 매개변수를 직접 관리하거나 Flax, Equinox 같은 라이브러리를 사용해야 합니다.

JAX 생태계

JAX는 기본 연산을 제공합니다. 라이브러리는 사용성을 보완합니다.

라이브러리역할스타일
Flax(Google)신경망 계층(layer)명시적 상태를 가진 nn.Module
Equinox(Patrick Kidger)신경망 계층파이트리 기반, Pythonic
Optax(DeepMind)옵티마이저와 학습률(LR) 스케줄조합 가능한 기울기 변환(gradient transform)
Orbax(Google)체크포인팅(Checkpointing)파이트리 저장/복원
CLU(Google)지표(metric)와 로깅학습 루프 유틸리티

Optax는 표준 옵티마이저 라이브러리입니다. Adam, SGD, 클리핑(clipping) 같은 기울기 변환을 매개변수 갱신과 분리하므로 조합이 쉽습니다.

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

JAX와 PyTorch를 선택하는 기준

기준JAXPyTorch
TPU 지원일급(first-class) 지원커뮤니티 유지(torch_xla)
GPU 지원좋음(XLA를 통한 CUDA)최고 수준(native CUDA)
디버깅어려움(추적 + 컴파일)쉬움(즉시 실행, 줄 단위)
생태계연구 중심(Flax, Equinox)매우 큼(HuggingFace, torchvision 등)
채용좁은 시장(Google/DeepMind/Anthropic)주류(mainstream)
대규모 학습강함(XLA, pmap, mesh)강함(FSDP, DeepSpeed)
프로토타이핑 속도느릴 수 있음(함수형 오버헤드)빠름(변경하며 바로 진행)
운영 환경 추론TensorFlow Serving, Vertex AITorchServe, Triton, ONNX
사용하는 조직DeepMind(Gemini), Anthropic(Claude)Meta(Llama), OpenAI(GPT), Stability AI

솔직한 기준은 이렇습니다. 특별한 이유가 없다면 PyTorch를 사용합니다. JAX를 선택할 이유는 TPU 접근성, 예시별 기울기 계산, 대규모 다중 장치 학습, Google/DeepMind/Anthropic 스타일의 작업 환경 등입니다.

JAX의 난수(Random Numbers)

JAX에는 전역 난수 상태(global random state)가 없습니다. 모든 난수 연산은 명시적 PRNG 키가 필요합니다.

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))

처음에는 번거롭지만, 장치와 컴파일을 넘나들 때 재현성(Reproducibility)을 보장하는 데 도움이 됩니다. 여러 GPU 환경에서 torch.manual_seed만으로는 보장하기 어려운 성질입니다.

만들어 보기

Step 1: 설정과 데이터

JAX와 Optax로 MNIST에 3계층 MLP를 학습합니다. 입력 784개, 은닉 계층(hidden layer)은 각각 256개와 128개의 뉴런(neuron), 출력 클래스(output class)는 10개입니다.

import jax
import jax.numpy as jnp
from jax import random
import optax

def get_mnist_data():
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype('float32') / 255.0
    y = mnist.target.astype('int')
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    return X_train, y_train, X_test, y_test

Step 2: 매개변수(Parameter) 초기화

클래스 없이 파이트리를 반환하는 함수만 둡니다.

def init_params(key):
    k1, k2, k3 = random.split(key, 3)
    scale1 = jnp.sqrt(2.0 / 784)
    scale2 = jnp.sqrt(2.0 / 256)
    scale3 = jnp.sqrt(2.0 / 128)
    params = {
        'layer1': {'w': scale1 * random.normal(k1, (784, 256)), 'b': jnp.zeros(256)},
        'layer2': {'w': scale2 * random.normal(k2, (256, 128)), 'b': jnp.zeros(128)},
        'layer3': {'w': scale3 * random.normal(k3, (128, 10)), 'b': jnp.zeros(10)},
    }
    return params

He 초기화(He initialization)를 수동으로 적용합니다. 하나의 시드(seed)에서 세 개의 PRNG 키를 나눕니다. 모든 가중치(weight)는 중첩된 딕셔너리 안의 불변 배열입니다.

Step 3: 순전파(Forward Pass)

def forward(params, x):
    x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
    return x

def loss_fn(params, x, y):
    logits = forward(params, x)
    one_hot = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))

순수 함수입니다. params가 들어오고 예측(prediction)이 나갑니다. self도, 저장된 상태도 없습니다. loss_fn은 교차 엔트로피(cross-entropy)를 직접 계산합니다.

Step 4: JIT 컴파일된 학습 단계(Training Step)

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

@jax.jit
def accuracy(params, x, y):
    logits = forward(params, x)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == y)

jax.value_and_grad는 한 번의 패스(pass)에서 손실 값(loss value)과 기울기를 모두 반환합니다. @jax.jit 데코레이터(decorator)는 두 함수를 XLA로 컴파일합니다. 첫 호출 이후에는 Python을 거치지 않고 각 학습 단계가 실행됩니다.

Step 5: 학습 루프(Training Loop)

optimizer = optax.adam(learning_rate=1e-3)

X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)

batch_size = 128
n_epochs = 10

for epoch in range(n_epochs):
    key, subkey = random.split(key)
    perm = random.permutation(subkey, len(X_train))
    X_shuffled = X_train[perm]
    y_shuffled = y_train[perm]

    epoch_loss = 0.0
    n_batches = len(X_train) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        xb = X_shuffled[start:start + batch_size]
        yb = y_shuffled[start:start + batch_size]
        params, opt_state, loss = train_step(params, opt_state, xb, yb)
        epoch_loss += loss

    train_acc = accuracy(params, X_train[:5000], y_train[:5000])
    test_acc = accuracy(params, X_test, y_test)
    print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

10 에포크(epoch) 후 테스트 정확도(test accuracy)는 약 97%입니다. 첫 에포크는 JIT 컴파일 때문에 느립니다. 2~10 에포크는 빠릅니다.

여기에는 .zero_grad(), .backward(), .step()이 없습니다. 갱신 전체가 조합된(composed) 하나의 함수 호출입니다. 기울기 계산, Adam 변환, 매개변수 적용이 모두 train_step 안에서 일어납니다.

사용하기

Flax: Google 표준

Flax는 가장 널리 쓰이는 JAX 신경망 라이브러리입니다. nn.Module을 다시 제공하지만 상태 관리는 명시적입니다.

import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)

구조는 PyTorch와 비슷하지만 params가 모델과 분리되어 있습니다. model.init()이 매개변수를 만들고, model.apply(params, x)가 순전파를 실행합니다. 모델 객체 자체에는 상태가 없습니다.

Equinox: 파이썬다운(Pythonic) 대안

Equinox는 모델을 파이트리로 표현합니다.

import equinox as eqx

model = eqx.nn.MLP(
    in_size=784, out_size=10, width_size=256, depth=2,
    activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)

모델 자체가 파이트리이므로 .apply()가 필요 없습니다. 매개변수는 모델의 잎(leaf)입니다. JAX가 생각하는 방식에 더 가까운 형태입니다.

Optax: 조합 가능한 옵티마이저

Optax는 기울기 변환과 갱신을 분리합니다.

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3,
    warmup_steps=1000, decay_steps=50000
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)

기울기 클리핑(Gradient clipping), 학습률 워밍업(learning rate warmup), 가중치 감쇠(weight decay)를 체인(chain)으로 조합합니다. 각 변환은 기울기를 받고 수정한 뒤 다음 변환으로 넘깁니다. 하나의 거대한 옵티마이저 클래스에 모든 동작을 몰아넣지 않습니다.

산출물 만들기

이 lesson의 최종 산출물은 다음 두 가지입니다.

  • outputs/prompt-jax-optimizer.md: JAX 옵티마이저 구성(optimizer configuration)을 선택하는 prompt
  • outputs/skill-jax-patterns.md: JAX의 함수형 패턴(functional pattern)을 정리한 skill

설치와 운영 참고

pip install jax jaxlib optax flax
pip install jax[cuda12]  # GPU 지원

TPU(Google Cloud)에서는 다음처럼 설치합니다.

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

성능 관련 주의점은 다음과 같습니다.

  • 첫 JIT 호출은 컴파일 때문에 느립니다. 벤치마크(benchmark) 전 워밍업(warm up)을 수행합니다.
  • JIT 안에서 JAX 배열을 Python 반복문으로 순회하지 않습니다. jax.lax.scan, jax.lax.fori_loop를 사용합니다.
  • JIT 내부에서 출력할 때는 일반 print 대신 jax.debug.print()를 사용합니다.
  • jax.profiler나 TensorBoard로 프로파일링(profile)합니다. XLA 컴파일은 병목을 숨길 수 있습니다.
  • JAX는 기본적으로 GPU 메모리의 75%를 사전 할당(pre-allocate)합니다. 필요하면 XLA_PYTHON_CLIENT_PREALLOCATE=false를 설정합니다.

체크포인팅은 Orbax로 파이트리를 저장하고 복원합니다.

import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')

연습문제

  1. (쉬움) MLP에 드롭아웃(dropout)을 추가합니다. JAX의 드롭아웃은 PRNG 키가 필요하므로 순전파에 키를 전달하고 드롭아웃 계층마다 분할(split)합니다. 적용 전후 테스트 정확도를 비교합니다.
  2. (중간) jax.vmap으로 MNIST 32개 배치의 예시별 기울기를 계산합니다. 각 예시의 기울기 노름(gradient norm)을 구합니다. 어떤 예시의 기울기가 가장 크고, 왜 그런지 분석합니다.
  3. (중간) 수동으로 작성한 순전파 함수를 임의의 계층 수에 대응하는 mlp_forward(params, x)로 바꿉니다. jax.tree.leaves로 깊이(depth)를 자동 판단합니다.
  4. (중간) @jax.jit 적용 전후의 학습 단계를 벤치마크합니다. 각각 100단계를 측정합니다. 사용 중인 하드웨어에서 속도 향상(speedup)과 첫 호출의 컴파일 부가 비용이 어느 정도인지 확인합니다.
  5. (어려움) optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))로 기울기 클리핑을 구현합니다. 클리핑 전후로 학습하고 기울기 노름을 시각화(plot)합니다.

핵심 용어

용어흔한 설명실제 의미
XLAJAX를 빠르게 만드는 것계산 그래프에서 연산을 융합하고 GPU/TPU 커널을 생성하는 컴파일러
JIT적시 컴파일(Just-in-time compilation)첫 호출에서 함수를 추적하고 XLA로 컴파일한 뒤 이후 호출에서 컴파일된 버전을 실행하는 방식
순수 함수(Pure Function)부작용이 없음출력이 오직 입력에만 의존하며 전역 상태, 변경, 명시적 키 없는 난수성이 없는 함수
vmap자동 배치(auto-batching)하나의 예시를 처리하는 함수를 배치를 처리하는 함수로 변환하는 변환
pmap자동 병렬화(auto-parallelism)함수를 여러 장치에 복제하고 입력 배치를 분할하는 변환
파이트리(Pytree)배열의 중첩 딕셔너리JAX가 순회하고 변환할 수 있는 리스트, 튜플, 딕셔너리, 배열의 중첩 구조
추적(tracing)계산을 기록하는 것실제 값을 계산하기보다 추상값(abstract value)으로 함수를 실행해 계산 그래프를 만드는 과정
함수형 자동 미분(functional autodiff)함수의 기울기텐서에 기울기 저장소를 붙이는 대신 함수를 변환해 도함수를 계산하는 방식
OptaxJAX 옵티마이저 라이브러리Adam, SGD, 클리핑, 스케줄링을 체인으로 조합하는 기울기 변환 라이브러리
FlaxJAX의 nn.Module명시적 상태를 유지하면서 계층 추상화(layer abstraction)를 제공하는 Google의 JAX 신경망 라이브러리

더 읽을거리

실습 코드

이 강의의 실습 코드 1개

jax intro
Code

산출물

이 강의에서 생성된 프롬프트, 스킬, 코드 산출물 2개

skill-jax-patterns

Functional programming patterns in JAX -- when and how to use grad, jit, vmap, and pmap

Skill
prompt-jax-optimizer

Choose and configure the right JAX/Optax optimizer for a given training scenario

Prompt

확인 문제

3문제 · 모두 맞추면 완료 표시가 가능합니다

1.jax.vmap은 무엇을 합니까?

2.JAX는 모델 상태(model state; 가중치)를 PyTorch와 어떻게 다르게 다룹니까?

3.언제 PyTorch보다 JAX를 선택하겠습니까?

0/3 답변 완료