[Pytorch] 분산 학습 3: 완전 분할 데이터 병렬 처리 (FSDP)

Date:     Updated:

카테고리:

완전 분할 데이터 병렬 처리 (Fully Sharded Data Parallel, FSDP)

FSDP의 개념

완전 분할 데이터 병렬 처리 (Fully Sharded Data Parallel, FSDP)는 모델 파라미터·그라디언트·옵티마 상태(optimizer states)를 GPU들에 샤딩하여 중복 메모리를 거의 제거하고, 계산과 통신을 겹쳐(overlap) 대형 모델을 단일 서버에서도 안정적으로 학습하게 해주는 PyTorch 표준 분산 학습 엔진이다. 이 때, GPU 수를 workers 혹은 rank라고 한다. DDP와 달리 FSDP는 각 GPU에 모델을 나눠서 복제하기 때문에 메모리 사용량이 훨씬 많이 줄어든다.

  • 정의: FSDP는 PyTorch의 공식 모듈로, 파라미터(param)·그라디언트(grad)·옵티마 상태(opt state)각 GPU에 분산 저장(샤딩) 하는 데이터 병렬의 발전형이다.
  • 핵심 목표: DDP가 갖는 “모든 GPU가 모델 전체 복사본을 보유”하는 중복 메모리를 제거하여 훨씬 큰 모델/배치/시퀀스 길이를 처리하는 것이다. 필요한 순간에만 조각을 모아(All-Gather) 계산하고, 계산이 끝나면 다시 흩뿌리고(Reduce-Scatter) 저장하여 메모리 상주량을 최소화한다.
  • PyTorch 2.1.0 이상, accelerate 필요

FSDP 동작 과정

FSDP는 모듈(레이어) 단위로 래핑되어 동작하며, 한 모듈의 파라미터는 GPU 개수(world size)만큼 쪼개져 각 GPU에 1/𝑁씩 보관된다.

  • Forward 전 준비: 현재 실행할 모듈의 파라미터를 All-Gather하여 일시적으로 완전한 파라미터를 메모리에 모음(연산 가능한 상태)입니다.
  • Forward/Backward 계산: 해당 모듈의 전·후향을 수행한다.
  • Backward 후 정리: 그라디언트를 Reduce-Scatter 하여 다시 샤드 단위로 나누어 각 GPU에 보관한다. 필요 시 파라미터를 즉시 해제(unshard → reshard) 하여 피크 메모리를 낮춘다.
  • Optimizer Step: 옵티마 상태 역시 샤딩되어 있어 각 GPU는 자신에게 있는 파라미터 샤드만 업데이트한다. 결과적으로 모든 GPU에 분산된 샤드들이 한 스텝 끝에 일관되게 갱신된다.

“필요할 때만 모았다가, 끝나면 바로 흩뿌린다” 로 이해하면 된다. 이를 통해 메모리 상주량 최소화 + 통신-계산 오버랩을 달성한다.

실제 GPU 사용량 비교

  • 모델 VRAM 요구량: 12GB (단일 GPU 기준)
  • 학습 환경: 4x GPU (각 GPU VRAM 용량 48GB 이상 가정)
  • 옵티마이저: Adam (파라미터의 2배 메모리 사용)

1

DDP 사용 시 GPU 당 모델 관련해서 단일 GPU의 VRAM 만 48GB가 필요하다.

1

  • 4등분 샤딩으로 GPU 당 모델 관련 VRAM 만 12 GB 필요.
  • 단, 순전파/역전파 시 임시적으로 All-Gather 를 위해 모델 내 가장 큰 Layer 의 VRAM 인 a GB 만큼 여분 필요.
  • 12+a + b(데이터셋 리소스)GB 만큼의 GPU VRAM 이면 모델 학습 가능.

주요 옵션과 개념

  • ShardingStrategy
    • FULL_SHARD: 파라미터·그라드·옵티마 모두 샤딩하는 기본·권장 전략이다.
    • SHARD_GRAD_OP: 일부만 샤딩하는 절충안(과거 호환/특수 케이스).
  • MixedPrecision
    • 파라미터/그라드/통신 dtype을 세밀하게 지정해 BF16/FP16로 메모리·대역폭을 절감한다.
  • CPUOffload
    • 일부 파라미터를 CPU로 내렸다 올릴 수 있으나, 단일 서버에서는 NVLink/NVSwitch 환경이 아니라면 속도 저하가 흔하다. 필요할 때만 신중히 사용한다.
  • Param Flattening & Bucketing
    • 다수의 작은 텐서를 플랫하게 합쳐 통신/할당 오버헤드를 줄인다.
    • 버킷 크기 조절로 통신-계산 오버랩 정도를 조정한다.
  • Auto Wrap Policy
    • 모듈 크기/타입 기반 자동 래핑으로, 적절한 파티션 경계를 만들고 피크 메모리를 낮춘다.
  • State Dict 유형
    • full_state_dict: 모든 샤드를 모아 완전한 체크포인트(보통 rank 0에서 저장)
    • sharded_state_dict: 샤드 형태로 분산 저장/로드(대규모에서 효율적)
    • local_state_dict: 각 랭크 로컬 상태(특수 관리용)



Reference

Blog: FSDP 쉽게 설명하기

Pytorch 카테고리 내 다른 글 보러가기

댓글 남기기