전이 학습(Transfer Learning)과 미세 조정(Fine-Tuning)

누군가는 신경망(Network)이 윤곽선(edge), 질감(texture), 사물 부위(object part)를 보도록 백만 GPU-시간(GPU-hour)을 썼습니다. 직접 학습하기 전에 그 특징(feature)을 빌려와야 합니다.

유형: Build 언어: Python 선수 학습: Phase 4 Lesson 03 (CNNs), Phase 4 Lesson 04 (Image Classification) 소요 시간: 약 75분

학습 목표

  • 특징 추출(feature extraction)과 미세 조정(fine-tuning)을 구분하고, 데이터셋 크기(dataset size), 도메인 거리(domain distance), 계산 예산(compute budget)에 따라 적절한 방법을 고릅니다.
  • 사전 학습된 백본(pretrained backbone)을 불러오고 분류기 헤드(classifier head)를 교체한 뒤, 헤드만 학습해 기준선(baseline)을 만듭니다.
  • 변별적 학습률(discriminative learning rate)로 층(layer)을 점진적으로 해제(unfreeze)하여, 앞쪽의 일반적 특징(early generic feature)은 작게, 뒤쪽의 과제 특화 특징(late task-specific feature)은 크게 갱신(update)합니다.
  • 학습률이 너무 높을 때 발생하는 특징 표류(feature drift), 작은 데이터셋에서의 BatchNorm 통계량 붕괴(BN statistics collapse), 그리고 치명적 망각(catastrophic forgetting)을 진단합니다.

문제

ImageNet에서 ResNet-50을 학습하려면 약 2,000 GPU-시간이 듭니다. 대부분의 팀은 모든 과제(task)마다 그런 예산을 쓰지 않습니다. 실제 배포 현장에서는 사전 학습된 백본에 새 헤드를 붙이고, 수백~수천 장의 과제 특화 이미지(task-specific image)로 학습하는 경우가 많습니다.

이는 단순한 지름길(shortcut)이 아닙니다. ImageNet으로 학습된 합성곱 신경망(CNN)의 첫 번째 합성곱 블록(conv block)은 윤곽선과 가버 유사 필터(Gabor-like filter)를 학습합니다. 다음 블록은 질감과 단순한 모티프(motif)를, 중간 블록은 사물 부위를, 마지막 블록은 ImageNet의 1,000개 범주(category)와 닮은 조합을 학습합니다. 이 계층 구조(hierarchy)의 앞 90%는 의료 영상(medical imaging), 산업 검사(industrial inspection), 위성 데이터(satellite data) 같은 거의 모든 영역에도 큰 변경 없이 전이됩니다. 자연에는 윤곽선과 질감의 어휘(vocabulary)가 제한적이기 때문입니다. 실제로 직접 학습해야 하는 것은 마지막 10%뿐입니다.

전이 학습을 제대로 하려면 세 가지 버그(bug)를 피해야 합니다. 학습률(learning rate)이 너무 높아 사전 학습된 특징을 망가뜨리는 경우, 너무 많이 동결(freeze)해서 모델에 정보가 부족한 경우, 그리고 BatchNorm의 이동 통계량(running statistics)이 작은 목표 데이터셋 쪽으로 표류(drift)하는 경우입니다.

사전 테스트

2문제 · 이 강의를 시작하기 전에 얼마나 알고 있는지 확인해보세요

1.ImageNet 분포에 가까운 새 과제에 라벨링된 이미지 500장이 있습니다. 어떤 방식이 가장 적절합니까?

2.ImageNet에는 X-선 영상이 없는데도, ImageNet으로 사전 학습된 신경망의 앞쪽 합성곱 층이 의료 영상에 잘 전이되는 이유는 무엇입니까?

0/2 답변 완료

개념

특징 추출 대 미세 조정 (Feature extraction vs fine-tuning)

두 방식(regime)은 사전 학습된 특징을 얼마나 신뢰하는지, 그리고 데이터가 얼마나 있는지에 따라 선택합니다.

flowchart TB
    subgraph FE["Feature extraction — backbone frozen"]
        FE1["Pretrained backbone<br/>(no gradient)"] --> FE2["New head<br/>(trained)"]
    end
    subgraph FT["Fine-tuning — end-to-end"]
        FT1["Pretrained backbone<br/>(tiny LR)"] --> FT2["New head<br/>(normal LR)"]
    end

    style FE1 fill:#e5e7eb,stroke:#6b7280
    style FE2 fill:#dcfce7,stroke:#16a34a
    style FT1 fill:#fef3c7,stroke:#d97706
    style FT2 fill:#dcfce7,stroke:#16a34a
데이터셋 크기도메인 거리권장 방법
이미지 1천 장 미만ImageNet에 가까움백본 동결, 헤드만 학습
1천~1만 장가까움앞쪽 2~3개 단계(stage) 동결, 나머지 미세 조정
1만~10만 장임의변별적 학습률로 종단간(end-to-end) 미세 조정
10만 장 이상멀리 떨어짐전체 미세 조정. 도메인이 충분히 멀면 처음부터 학습(scratch training)도 고려

"ImageNet에 가깝다"는 것은 사물에 가까운 내용을 담은 자연 RGB 사진(natural RGB photo)을 뜻합니다. CT 촬영(CT scan), 위성 이미지(satellite imagery), 현미경 이미지(microscopy)는 먼 도메인(far domain)에 해당합니다. 특징은 여전히 도움이 되지만, 더 많은 층을 적응(adapt)시켜야 합니다.

동결이 통하는 이유

합성곱 신경망이 ImageNet에서 학습한 특징은 1,000개 범주에만 특화된 것이 아닙니다. 윤곽선의 방향(edge orientation), 질감, 명암 패턴(contrast pattern), 형태 원시 요소(shape primitive)처럼 자연 이미지의 통계량에 특화되어 있습니다. 이 통계량은 대부분의 시각 도메인(visual domain)에서 어느 정도 공유됩니다. 그래서 동결된 백본 위에 선형 헤드(linear head)만 학습해도 CIFAR-10에서 높은 정확도(accuracy)가 나오는 것입니다.

변별적 학습률

층을 해제할 때 앞쪽 층은 뒤쪽 층보다 느리게 학습해야 합니다. 앞쪽 층은 보존하고 싶은 일반적 특징을 담고 있고, 뒤쪽 층은 새로운 과제에 맞춰 많이 움직여야 합니다.

stage 0 (stem + first group): lr = base_lr / 100
stage 1:                      lr = base_lr / 10
stage 2:                      lr = base_lr / 3
stage 3 (last backbone):      lr = base_lr
head:                         lr = base_lr 또는 그보다 약간 큼

PyTorch에서는 옵티마이저(optimizer)에 매개변수 그룹(parameter group) 리스트를 넘기면 됩니다. 하나의 모델에 여러 학습률을 줄 수 있습니다.

BatchNorm 문제

BatchNorm 층은 ImageNet에서 계산된 running_mean, running_var 버퍼(buffer)를 가지고 있습니다. 목표 과제의 픽셀 분포(pixel distribution)가 다르면 이 버퍼가 맞지 않습니다. 선택지는 다음 세 가지입니다.

  1. BatchNorm을 학습 모드(train mode)로 미세 조정. 데이터셋이 중간 크기 이상(예: 5천 장 이상)이면 이동 통계량을 목표 데이터에 맞춥니다.
  2. BatchNorm을 평가 모드(eval mode)로 동결. 데이터셋이 작아 이동 평균(moving average)이 잡음(noisy)이 많을 때 ImageNet의 통계량을 유지합니다.
  3. GroupNorm으로 교체. 이동 평균 문제 자체를 제거합니다. GPU당 배치 크기(batch size)가 작은 객체 탐지(detection)나 분할(segmentation)에서 자주 사용합니다.

잘못 선택하면 정확도가 5~15%p 떨어질 수 있습니다.

헤드 설계 (Head design)

분류기 헤드는 보통 1~3개의 선형 층(linear layer)과 선택적인 드롭아웃(dropout)으로 구성됩니다. torchvision 백본의 기본 헤드(default head)를 교체합니다.

backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)       # ResNet
backbone.classifier[1] = nn.Linear(..., num_classes)                # EfficientNet, MobileNet
backbone.heads.head = nn.Linear(..., num_classes)                   # torchvision ViT

작은 데이터셋에서는 단일 선형 층 하나로 충분합니다. 과제 분포(task distribution)가 백본의 학습 분포와 멀리 떨어져 있으면 은닉 층(hidden layer)과 드롭아웃을 추가할 수 있습니다.

층별 학습률 감쇠 (Layer-wise LR decay)

최근의 미세 조정에서는 층마다 조금씩 다른 학습률을 주기도 합니다.

lr_layer_k = base_lr * decay^(L - k)

감쇠율(decay)이 0.75이고 트랜스포머 블록(transformer block)이 12개라면, 첫 번째 블록은 헤드 학습률의 약 0.75^11 ≈ 0.04배로 학습됩니다. 합성곱 신경망에서는 단계별로 묶은 학습률(stage-grouped LR)만으로도 충분한 경우가 많습니다.

무엇을 평가할까

전이 학습 실행(run)에서는 처음부터 학습하는 실행에서는 잘 보지 않는 두 숫자를 추적합니다.

  • 사전 학습만의 정확도 (Pretrained-only accuracy): 백본을 동결한 상태에서 헤드만 학습했을 때의 정확도. 바닥값(floor)입니다.
  • 미세 조정 정확도 (Fine-tuned accuracy): 종단간 학습 이후의 정확도. 천장값(ceiling)입니다.

미세 조정 정확도가 사전 학습만의 정확도보다 낮으면 학습률이나 BatchNorm 관련 버그를 의심해야 합니다. 두 값을 항상 함께 출력합니다.

만들어 보기

Step 1: 사전 학습된 백본 불러오기

import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(backbone)
print()
print("classifier head:", backbone.fc)
print("feature dim:", backbone.fc.in_features)

ResNet18에는 줄기(stem), layer1..layer4, 그리고 fc 헤드가 있습니다. torchvision의 분류용 백본은 대체로 비슷한 구조를 갖습니다.

Step 2: 특징 추출 — 전체를 동결하고 헤드 교체하기

def make_feature_extractor(num_classes=10):
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    for p in model.parameters():
        p.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = make_feature_extractor(num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:>10,}")
print(f"frozen:    {frozen:>10,}")

이제 model.fc만 학습 가능하고, 백본은 동결된 특징 추출기(frozen feature extractor)가 됩니다.

Step 3: 변별적 미세 조정

def discriminative_param_groups(model, base_lr=1e-3, decay=0.3):
    stages = [["conv1", "bn1"], ["layer1"], ["layer2"], ["layer3"], ["layer4"], ["fc"]]
    groups = []
    for i, names in enumerate(stages):
        lr = base_lr * (decay ** (len(stages) - 1 - i))
        params = [p for n, p in model.named_parameters()
                  if any(n.startswith(k) for k in names)]
        if params:
            groups.append({"params": params, "lr": lr, "name": "_".join(names)})
    return groups

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
for p in model.parameters():
    p.requires_grad = True

groups = discriminative_param_groups(model)
for g in groups:
    print(f"{g['name']:>10s}  lr={g['lr']:.2e}  params={sum(p.numel() for p in g['params']):>8,}")

decay=0.3이면 fcbase_lr로, layer40.3 * base_lr로, conv10.3^5 * base_lr로 학습됩니다. 극단적으로 들리지만 경험적으로 잘 작동합니다.

Step 4: BatchNorm 처리

def freeze_bn_stats(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()
            for p in m.parameters():
                p.requires_grad = False
    return model

에폭(epoch) 시작 때 model.train()을 호출한 뒤, BatchNorm 층만 다시 평가 모드로 되돌릴 때 사용합니다.

Step 5: 최소한의 미세 조정 루프

fine_tune은 변별적 매개변수 그룹, SGD 옵티마이저, 코사인 스케줄러(cosine scheduler)를 사용합니다. 매 에폭마다 학습 루프(training loop)를 돌고 검증 정확도(validation accuracy)를 출력합니다.

from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

def fine_tune(model, train_loader, val_loader, device, epochs=5, base_lr=1e-3, freeze_bn=False):
    model = model.to(device)
    groups = discriminative_param_groups(model, base_lr=base_lr)
    optimizer = SGD(groups, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        if freeze_bn:
            freeze_bn_stats(model)
        tr_loss, tr_correct, tr_total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=0.1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tr_loss += loss.item() * x.size(0)
            tr_total += x.size(0)
            tr_correct += (logits.argmax(-1) == y).sum().item()
        scheduler.step()

        model.eval()
        va_total, va_correct = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(-1)
                va_total += x.size(0)
                va_correct += (pred == y).sum().item()
        print(f"epoch {epoch}  train {tr_loss/tr_total:.3f}/{tr_correct/tr_total:.3f}  "
              f"val {va_correct/va_total:.3f}")
    return model

CIFAR-10에서 위 방법대로 5 에폭 학습하면 ResNet18-IMAGENET1K_V1은 약 70%의 영-숏 선형 탐사 정확도(zero-shot linear-probe accuracy)에서 약 93%의 미세 조정 정확도까지 올라갑니다. 헤드만 학습하면 백본을 건드리지 않은 상태에서 약 86% 부근에 머무를(plateau) 수 있습니다.

Step 6: 점진적 해제 (Progressive unfreezing)

마지막 단계부터 앞쪽으로 하나씩 해제하는 스케줄(schedule)입니다. 에폭마다 학습 가능한 매개변수 집합이 바뀌면 옵티마이저를 다시 만들어야 합니다. 그렇지 않으면 동결된 매개변수의 캐싱된 모멘트(cached moment)가 남아 혼란을 만들 수 있습니다.

def progressive_unfreeze_schedule(model):
    stages = ["layer4", "layer3", "layer2", "layer1"]
    yielded = set()

    def start():
        for p in model.parameters():
            p.requires_grad = False
        for p in model.fc.parameters():
            p.requires_grad = True

    def unfreeze(epoch):
        if epoch < len(stages):
            name = stages[epoch]
            yielded.add(name)
            for n, p in model.named_parameters():
                if n.startswith(name):
                    p.requires_grad = True
            return name
        return None

    return start, unfreeze

첫 에폭 이전에 start()를 한 번 호출합니다. 각 에폭의 시작 때 unfreeze(epoch)를 호출합니다. 학습 가능한 매개변수 집합이 바뀔 때마다 옵티마이저를 다시 만들어야 합니다.

사용하기

대부분의 실무 과제는 torchvision.models와 몇 줄의 코드로 시작할 수 있습니다.

from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

상용 환경의 기본 선택으로는 timm.create_model("resnet50", pretrained=True, num_classes=10)도 많이 사용합니다. 트랜스포머 계열의 경우 transformers.AutoModelForImageClassification.from_pretrained(name, num_labels=N)로 ViT, BEiT, DeiT 같은 모델을 불러올 수 있습니다.

산출물 만들기

이 lesson의 최종 산출물은 다음 두 가지입니다.

  • outputs/prompt-fine-tune-planner.md: 데이터셋 크기, 도메인 거리, 계산 예산에 따라 특징 추출, 점진적 미세 조정, 종단간 미세 조정 가운데 적절한 방식을 선택하는 프롬프트
  • outputs/skill-freeze-inspector.md: PyTorch 모델의 학습 가능한 매개변수, BatchNorm 평가 모드 여부, 옵티마이저 커버리지를 점검하는 스킬

연습문제

  1. 쉬움: 같은 합성 CIFAR(synthetic-CIFAR) 데이터셋에서 ResNet18을 선형 탐사(linear probe; 백본 동결)와 전체 미세 조정(full fine-tune)으로 각각 학습합니다. 두 정확도를 나란히 보고합니다.
  2. 중간: 백본 단계에 base_lr=1e-1을 주는 버그를 의도적으로 넣습니다. 학습 손실(training loss)이 발산하는 모습을 보인 뒤 discriminative_param_groups로 복구합니다.
  3. 어려움: 의료 영상 데이터셋(CheXpert-small, PatchCamelyon, HAM10000 등)에서 ImageNet으로 사전 학습된 동결 백본, ImageNet 사전 학습 종단간 미세 조정, 처음부터 학습 세 가지 방식을 비교합니다. 정확도와 계산 비용(compute cost)을 보고하고, 처음부터 학습이 경쟁력을 갖게 되는 데이터셋 크기를 찾습니다.

핵심 용어

용어흔한 설명실제 의미
특징 추출 (Feature extraction)"동결하고 헤드만 학습"백본의 매개변수를 동결하고 새 분류기 헤드만 그래디언트(gradient)를 받는 방식
미세 조정 (Fine-tuning)"종단간 재학습"사전 학습된 매개변수를 작은 학습률로 과제에 맞게 갱신하는 방식
변별적 학습률 (Discriminative LR)"앞쪽 층은 작은 학습률"단계별 옵티마이저 매개변수 그룹에 서로 다른 학습률을 주는 방식
층별 학습률 감쇠 (Layer-wise LR decay)"부드러운 학습률 변화"층이 아래로 갈수록 감쇠율을 곱해 학습률을 작게 만드는 방식
치명적 망각 (Catastrophic forgetting)"ImageNet 특징을 잃음"학습률이 너무 높아, 새 과제의 신호를 학습하기 전에 사전 학습된 특징을 덮어쓰는 현상
BatchNorm 통계량 표류 (BN statistics drift)"이동 평균이 틀림"BatchNorm의 running_mean/running_var가 현재 과제 분포와 맞지 않아 정확도를 떨어뜨리는 현상
선형 탐사 (Linear probe)"동결된 백본 + 선형 헤드"사전 학습된 표현(representation) 위에 최적의 선형 분류기를 학습해 특징 품질을 평가하는 방법
치명적 붕괴 (Catastrophic collapse)"한 클래스만 예측"특징이 망가지고 헤드의 그래디언트가 안정되기 전에 모델이 무너지는(collapse) 현상

더 읽을거리

실습 코드

이 강의의 실습 코드 1개

main
Code

산출물

이 강의에서 생성된 프롬프트, 스킬, 코드 산출물 2개

skill-freeze-inspector

Report which parameters are trainable, which BatchNorm layers are in eval mode, and whether the optimizer is actually consuming the trainable parameters

Skill
prompt-fine-tune-planner

Pick feature extraction vs progressive vs end-to-end fine-tuning given dataset size, domain distance, and compute budget

Prompt

확인 문제

3문제 · 모두 맞추면 완료 표시가 가능합니다

1.변별적 학습률(discriminative learning rate)로 종단간 미세 조정할 때, 앞쪽 층이 뒤쪽 층보다 작은 학습률을 가져야 하는 이유는 무엇입니까?

2.10개 클래스의 의료 영상 데이터셋(그레이스케일 800장을 3채널로 복제)에 대해 ResNet을 미세 조정했더니 정확도가 10%(무작위 추측 수준)였습니다. 가장 가능성 높은 원인은 무엇입니까?

3.두 실행을 비교했습니다. (a) ImageNet으로 사전 학습된 동결 백본 위의 선형 탐사 정확도가 82%, (b) 종단간 미세 조정 정확도가 78%입니다. 어떤 결론을 내려야 합니까?

0/3 답변 완료