Flamingo와 Gated Cross-Attention — Few-Shot VLM 구조
DeepMind의 Flamingo(2022)는 다른 모델들보다 먼저 두 가지를 보여주었습니다. 하나는 단일 모델이 이미지, 비디오, 텍스트가 임의로 섞인 시퀀스(arbitrarily interleaved sequence)를 처리할 수 있다는 점이고, 다른 하나는 시각 언어 모델(Vision-Language Model; VLM)도 문맥 내 학습(In-Context Learning)을 할 수 있다는 점입니다. 예를 들어 (이미지, 캡션) 예시 세 쌍을 few-shot 프롬프트로 주면, 모델은 추가적인 경사 업데이트(gradient step) 없이 새 이미지의 캡션을 생성합니다. 그 핵심 메커니즘은 동결된 거대 언어 모델(frozen LLM)의 기존 계층 사이에 삽입되는 게이트형 교차 어텐션(Gated Cross-Attention) 계층입니다. 이 계층에는 0에서 시작하는 학습 가능한 tanh 게이트가 달려 있어, 초기화 시점에는 LLM의 텍스트 능력이 그대로 보존됩니다. 이 lesson에서는 Flamingo의 퍼시버 리샘플러(Perceiver Resampler)와 게이트형 교차 어텐션 구조를 차례대로 살펴봅니다. 이 설계는 Gemini의 interleaved input과 Idefics2의 시각 토큰(visual token) 계열로 이어지는 조상 격의 구조입니다.
유형: Learn
언어: Python (표준 라이브러리, gated cross-attention + Perceiver resampler 데모)
선수 지식: Phase 12 · 03 (BLIP-2 Q-Former)
예상 시간: 약 120분
학습 목표
- 게이트형 교차 어텐션이
tanh(gate) = 0을 통해 초기화 시점에 동결된 LLM의 텍스트 능력을 어떻게 보존하는지 설명합니다.
- 퍼시버 리샘플러가 N개의 이미지 패치(patch)를 K개의 고정 잠재 질의(latent query)로 바꾸는 교차 어텐션 흐름을 따라갑니다.
- Flamingo가 이미지와 텍스트가 섞여 있는 시퀀스를 이미지 위치를 존중하는 인과 마스킹(causal masking)으로 처리하는 방식을 설명합니다.
- (이미지, 캡션) 예시 세 개 뒤에 질의 이미지(query image)를 붙이는 멀티모달 few-shot 프롬프트 구조를 재현합니다.
문제
BLIP-2는 32개의 시각 토큰을 동결된 LLM의 입력 계층으로 넣습니다. 프롬프트 하나에 이미지 한 장이 들어가는 상황에서는 잘 동작합니다. 하지만 "여기 이미지 A가 있으니 캡션을 붙여라, 여기 이미지 B가 있으니 캡션을 붙여라, 이제 이미지 C가 있으니 캡션을 붙여라"처럼 텍스트 사이에 여러 이미지를 섞어 넣고 싶다면 어떻게 해야 할까요? LLM의 셀프 어텐션(self-attention)은 이미지 토큰과 텍스트 토큰을 하나의 시퀀스 안에서 함께 처리해야 하고, 어떤 위치가 어떤 이미지를 볼 수 있는지 정하는 문제가 금방 까다로워집니다.
Flamingo의 답은 LLM의 입력 시퀀스를 전혀 바꾸지 않는 것입니다. 대신 기존 LLM 블록(block) 사이에 추가 교차 어텐션 계층을 삽입합니다. 텍스트 토큰은 늘 그렇듯 LLM의 인과 셀프 어텐션을 따라 흐릅니다. 그리고 일정한 LLM 블록마다, 텍스트 토큰이 새로 끼워 넣은 게이트 계층을 통해 이미지 특징(feature)에 교차 어텐션(cross-attend)합니다. 이 게이트는 0으로 초기화되므로 학습이 시작되는 시점에는 새 계층이 아무 일도 하지 않는 무동작(no-op) 상태입니다. 즉 모델은 사전학습된 LLM과 정확히 동일하게 동작합니다. 학습이 진행되면 게이트가 서서히 열리면서 시각 정보가 흘러들어오기 시작합니다.
Flamingo가 답한 두 번째 질문은 프롬프트마다 이미지 수가 0개, 1개, 또는 여러 개로 달라질 때 어떻게 처리하느냐입니다. 여기서 퍼시버 리샘플러가 사용됩니다. 퍼시버 리샘플러는 입력 패치 수가 얼마든지 받아들인 뒤, 고정된 개수의 시각 잠재 토큰(visual latent token)을 만들어내는 작은 교차 어텐션 모듈입니다. 그 결과 LLM의 교차 어텐션 계층은 프롬프트 안에 이미지가 몇 장 있든 같은 모양(shape)의 입력을 보게 됩니다.
개념
동결된 LLM
Flamingo는 동결된 친칠라 70B(Chinchilla 70B) LLM에서 시작합니다. 700억 개의 가중치(weight)는 전혀 건드리지 않습니다. 기존 텍스트 셀프 어텐션과 피드포워드 신경망(Feed-Forward Network; FFN)은 평소처럼 동작합니다.
Perceiver 리샘플러
프롬프트에 있는 각 이미지에 대해 ViT는 N개의 패치 토큰을 생성합니다. 퍼시버 리샘플러에는 K개의 고정된 학습 가능 잠재 벡터(latent)가 있습니다. Flamingo는 K=64를 사용합니다. 각 리샘플러 블록은 두 개의 하위 단계로 이루어집니다.
- 교차 어텐션(Cross-Attention): K개의 잠재 벡터가 N개의 패치 토큰 전체를 바라봅니다. 질의(Query, Q)는 잠재 벡터에서 오고, 키와 값(Key/Value, K/V)은 패치에서 옵니다.
- 셀프 어텐션과 FFN: 잠재 벡터 내부에서 셀프 어텐션과 FFN을 수행합니다.
여섯 개의 리샘플러 블록을 지난 뒤 출력은 ViT가 만든 패치 수와 무관하게 차원(dimension)이 1024인 64개의 시각 토큰이 됩니다. 224x224 이미지(196개 패치)와 480x480 이미지(900개 패치)는 모두 64개의 리샘플러 토큰으로 나옵니다.
비디오의 경우 리샘플러는 시간 방향으로 적용됩니다. 각 프레임(frame)의 패치가 64개의 잠재 벡터를 만들고, 시간적 위치 인코딩(temporal positional encoding)이 모델이 t=0과 t=N을 구분할 수 있게 해줍니다. 결과적으로 전체 비디오는 T * 64개의 시각 토큰으로 표현됩니다.
게이트형 교차 어텐션
동결된 LLM의 M개 계층마다 Flamingo는 새로운 게이트형 교차 어텐션 블록을 삽입합니다. Flamingo는 M=4를 사용합니다.
x_after_llm_block = llm_block(x_before)
cross = cross_attn(x_after, resampler_output)
gated = tanh(alpha) * cross + x_after
x_before_next_block = gated
alpha는 0으로 초기화되는 학습 가능한 스칼라(scalar)입니다.
tanh(0) = 0이므로 초기화 시점에는 게이트 분기(gated branch)가 전혀 기여하지 않습니다.
alpha가 0에서 멀어질수록 교차 어텐션의 기여가 부드럽게 커집니다.
- 잔차 연결(residual connection)이 있기 때문에 게이트가 완전히 열려도 LLM의 텍스트 표현을 덮어쓰지 않고, 기존 표현 위에 시각 정보를 더하기만 합니다.
이것이 Flamingo에서 가장 중요한 설계 선택입니다. 시각 조건부 정보(visual conditioning)는 더해지는 방식이고, 게이트로 조절되며, 초기화 시점에는 0입니다. 학습 0단계의 Flamingo는 텍스트만 입력으로 받았을 때 완전한 친칠라 70B와 동일하게 동작합니다.
Interleaved 입력을 위한 마스킹 교차 어텐션
<image A> caption A <image B> caption B <image C> ? 같은 프롬프트에서는 각 텍스트 토큰이 시퀀스에서 자신보다 앞에 나온 이미지만 볼 수 있어야 합니다. 교차 어텐션 마스크(mask)는 이 제약을 강제합니다. 위치 t의 텍스트 토큰은 이미지 인덱스가 i < i_t인 리샘플러 토큰만 볼 수 있습니다. 여기서 i_t는 위치 t보다 앞에 있는 가장 최근 이미지를 가리킵니다. "가장 최근의 이전 이미지만 본다"와 "앞선 모든 이미지를 본다"는 모두 가능한 선택지인데, Flamingo는 전자를 선택했습니다.
문맥 내 few-shot 학습 (In-context few-shot learning)
Flamingo 프롬프트는 다음과 같은 형태를 가집니다.
<image1> A photo of a cat. <image2> A photo of a dog. <image3> A photo of a
모델은 완성(completion) 패턴을 보고 "bird"(또는 image3가 실제로 보여주는 대상)를 출력합니다. 경사 업데이트는 전혀 없습니다. 동결된 LLM의 문맥 내 학습 능력이 게이트형 교차 어텐션을 통해 그대로 이어집니다. 이것이 Flamingo 논문의 핵심 메시지이며, Flamingo가 중요한 이유입니다.
학습 데이터
Flamingo는 세 종류의 데이터셋(dataset)으로 학습되었습니다.
- MultiModal MassiveWeb(M3W): 이미지와 텍스트가 섞여 있는 4,300만 개의 웹 페이지로, 읽기 순서(reading order)를 재구성합니다.
- Image-Text Pairs(ALIGN + LTIP): 44억 개의 (이미지, 텍스트) 쌍입니다.
- Video-Text Pairs(VTP): 2,700만 개의 짧은 비디오 클립(video clip)입니다.
OBELICS(2023)는 interleaved 웹 코퍼스(corpus)의 오픈 재현(open reproduction)입니다. Idefics, Idefics2, 그리고 대부분의 공개된 "Flamingo 계열" 모델은 OBELICS 계열 코퍼스로 학습됩니다.
OpenFlamingo와 Otter
OpenFlamingo(2023)는 Flamingo의 오픈 재현 모델입니다. 구조는 동일하게 퍼시버 리샘플러와 동결된 LLaMA 또는 MPT 위의 게이트형 교차 어텐션을 사용합니다. 체크포인트(checkpoint)는 3B, 4B, 9B 규모로 제공됩니다. 다만 더 작은 베이스 LLM과 더 적은 학습 데이터 때문에 품질은 Flamingo보다 뒤처집니다.
Otter(2023)는 OpenFlamingo를 바탕으로 MIMIC-IT이라는 멀티모달 지시 데이터셋(multimodal instruction dataset)에서 지시 튜닝(instruction tuning)을 수행한 모델입니다. 이를 통해 게이트형 교차 어텐션이 지시 따르기(instruction following) 작업에도 잘 동작한다는 점을 보여주었습니다.
후속 계열
- Idefics / Idefics2 / Idefics3: 허깅 페이스(Hugging Face)의 게이트형 교차 어텐션 계열입니다. 점점 단순해졌고, Idefics2는 리샘플러를 제거한 뒤 적응형 풀링(adaptive pooling)을 적용한 직접 패치 토큰(direct patch token) 방식을 사용합니다.
- Flamingo에서 Chameleon으로의 전환: 2024년 무렵 많은 팀이 조기 융합(early-fusion) 방식(Lesson 12.11)으로 옮겨갔습니다. 하지만 백본 동결(backbone freezing)이 필요한 운영(production) 환경에서는 Flamingo 스타일의 게이트형 교차 어텐션이 여전히 사용됩니다.
- Gemini의 interleaved 입력: 정확한 메커니즘은 비공개이지만, 개념적으로는 Flamingo가 보여준 interleaved 포맷의 유연성을 이어받은 것으로 볼 수 있습니다.
BLIP-2와의 비교
| 항목 | BLIP-2 | Flamingo |
|---|
| 시각 브리지(visual bridge) | 입력에서 한 번 Q-Former 사용 | M개 계층마다 게이트형 교차 어텐션 사용 |
| 시각 토큰 | 이미지당 32개 | 교차 어텐션 계층마다 이미지당 64개 |
| 동결된 LLM | 예 | 예 |
| Few-shot 문맥 내 학습 | 약함 | 강함 — 논문의 중심 주제 |
| Interleaved 입력 | 기본 지원 없음 | 예, 설계 목표 자체가 interleaved 입력 |
| 학습 데이터 | 1억 3천만 쌍 | 13억 쌍 + 4,300만 interleaved 페이지 |
| 학습 파라미터 수 | 1억 8,800만 개 학습 | 약 100억 개 학습(교차 어텐션 계층) |
| 컴퓨트 | A100 8개에서 며칠 | TPUv4 수천 개에서 몇 주 |
예산이 제한된 단일 이미지 시각 질의응답(Visual Question Answering; VQA)이라면 BLIP-2를 선택합니다. interleaved 입력, few-shot, 다중 이미지 추론(multi-image reasoning)이 필요하다면 Flamingo나 Idefics2 계열을 선택합니다.
사용해보기
code/main.py는 다음을 보여줍니다.
- 36개의 가짜 패치 토큰과 8개의 학습 가능 잠재 벡터를 사용하는 퍼시버 리샘플러입니다. 순수 Python으로 교차 어텐션을 구현합니다.
- 게이트형 교차 어텐션 단계입니다.
alpha = 0일 때 출력이 입력과 같아 LLM이 변하지 않음을 보여주고, alpha = 2.0일 때 시각 정보 기여가 섞여 들어오는 것을 보여줍니다.
(image 1) (text 1) (image 2) (text 2) 시퀀스에 대한 2차원 어텐션 마스크를 만드는 interleaved 마스크 빌더(builder)입니다.
산출물 만들기
이 lesson은 outputs/skill-gated-bridge-diagnostic.md를 만듭니다. 공개된 VLM 설정(resampler 사용 여부, 교차 어텐션 빈도, 게이트 방식)을 입력받아 Flamingo 계열 요소를 식별하고 동결 전략(freezing strategy)을 설명하는 스킬(skill)입니다. 파인튜닝(fine-tuning) 이후 텍스트 성능이 나빠진 원인을 디버깅할 때 특히 유용합니다. 흔한 답은 게이트가 너무 빠르게 또는 너무 크게 열렸다는 것입니다.
연습문제
-
쉬움: Flamingo-9B의 시각 파라미터 수를 계산합니다. 9B LLM + 1.4B 게이트형 교차 어텐션 계층 + 64M 리샘플러일 때, 전체 파라미터 중 학습되는 비율은 얼마입니까?
-
쉬움: PyTorch로 게이트형 잔차 y = tanh(alpha) * cross + x를 구현합니다. alpha=0이면 초기화 시점에 y==x가 정확히 성립한다는 것을 실험으로 보입니다.
-
중간: OpenFlamingo 논문 Section 3.2(arXiv:2308.01390)를 읽고, 각 프롬프트의 이미지 수가 서로 다른 배치(batch)에서 여러 이미지를 어떻게 처리하는지 확인합니다. 패딩(padding) 전략을 설명합니다.
-
어려움: Flamingo의 교차 어텐션 마스크는 왜 텍스트 토큰이 앞선 모든 이미지가 아니라 가장 최근의 이전 이미지만 보도록 만들었을까요? Flamingo 논문 Section 2.4를 읽고 절충안(tradeoff)을 설명합니다.
-
어려움: 문맥 내 few-shot을 실험합니다. 새로운 Flamingo 변종(variant)을 위해 "이미지 → 주요 객체(main object)의 색"을 다루는 예시 4개로 이루어진 프롬프트를 구성합니다. 예시 수를 0개에서 8개까지 바꿀 때 기대되는 정확도(accuracy) 패턴을 설명합니다.
핵심 용어
| 용어 | 흔한 설명 | 실제 의미 |
|---|
| 퍼시버 리샘플러(Perceiver resampler) | "고정 잠재 교차 어텐션" | 입력 패치 수가 달라도 K개의 고정 토큰을 만들어내는 모듈이다. |
| 게이트형 교차 어텐션(Gated cross-attention) | "tanh 게이트 브리지" | y = tanh(alpha)*cross + x 형태의 잔차 계층이다. alpha는 학습 가능하고 0으로 초기화된다. |
| Interleaved 입력 | "섞여 있는 시퀀스" | 이미지와 텍스트가 읽기 순서에 따라 자유롭게 섞여 있는 프롬프트 형식이다. |
| 동결된 LLM(Frozen LLM) | "LLM 경사 없음" | 텍스트 LLM의 가중치는 갱신하지 않고, 리샘플러와 교차 어텐션 계층만 학습한다. |
| Few-shot | "문맥 안 예시" | 프롬프트 안에 몇 개의 (이미지, 정답) 쌍을 제공하고, 파인튜닝 없이 모델이 일반화하게 하는 방식이다. |
| OBELICS | "Interleaved 웹 코퍼스" | 이미지와 텍스트가 읽기 순서대로 정렬된 1억 4,100만 개 웹 페이지 규모의 공개 데이터셋이다. |
| 친칠라(Chinchilla) | "70B 동결 베이스" | DeepMind의 Chinchilla 논문에서 나온, Flamingo의 동결된 텍스트 LLM이다. |
| 게이트 스케줄(Gate schedule) | "alpha가 움직이는 방식" | 학습 중 교차 어텐션 게이트가 열리는 속도와 패턴이다. |
| 교차 어텐션 빈도(Cross-attn frequency) | "M개 계층마다 한 번" | 게이트형 교차 어텐션 블록을 얼마나 자주 삽입하는지 나타낸다. Flamingo는 M=4를 사용한다. |
| OpenFlamingo | "오픈 재현 모델" | MosaicML/LAION의 3-9B 오픈 체크포인트로, 구조는 Flamingo와 동일하다. |
더 읽을거리