Notice
Recent Posts
Recent Comments
Link
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
Archives
Today
Total
관리 메뉴

끊김 없이 하자

Pytorch DataParallel 사용 시의 팁 본문

낑낑

Pytorch DataParallel 사용 시의 팁

도파미파 2021. 8. 17. 11:39

(내가 보려고) 당근마켓 팀블로그의 2019년 게시글을 약식 요약.
[원본] PyTorch Multi-GPU 제대로 학습하기 https://medium.com/daangn/pytorch-multi-gpu-%ED%95%99%EC%8A%B5-%EC%A0%9C%EB%8C%80%EB%A1%9C-%ED%95%98%EA%B8%B0-27270617936b

 

🔥PyTorch Multi-GPU 학습 제대로 하기

PyTorch를 사용해서 Multi-GPU 학습을 하는 과정을 정리했습니다. 이 포스트는 다음과 같이 진행합니다.

medium.com

 

1. DataParallel( )은 기본적으로 여러개의 GPU에서 계산되는 gradient를 하나의 GPU로 모아서 계산한다.
   gradient 연산에서는 계속 scatter - gather 가 일어난다. 

2. (1단계 팁) 이때, 출력(loss 계산에 필요)까지 같은 GPU로 모으면 하나의 GPU가 과도하게 GPU 메모리를 점거하는데 일조하게 되므로,
   다른 GPU를 지정해주면 좋다. (out of memory 가능성 줄임 -> 더 큰 배치사이즈를 지정할 수 있게 된다.)

os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'
model = nn.DataParallel(model, output_device=1)

3. (2단계 팁) 2는 GPU간 데이터 불균형에 대한 적절한 해결책은 아니다. Custom DataParallel을 사용하면 좋다. 
   pytorch-encoding 패키지(일례)를 사용하여 loss function까지 병렬 연산할 수 있도록 만든다. 

기존의 DataParallel.data_parallel()

def data_parallel(module, input, device_ids, output_device):
    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)

Pytorch-encoding의  DataParallelCriterion

from torch.nn.parallel.data_parallel import DataParallel

class DataParallelCriterion(DataParallel):
    def forward(self, inputs, *targets, **kwargs):
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids) # new
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        targets = tuple(targets_per_gpu[0] for targets_per_gpu in targets) # new
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) # substituted
        return Reduce.apply(*outputs) / len(outputs), targets

이용 예시

import torch
import torch.nn as nn
from parallel import DataParallelModel, DataParallelCriterion

model = BERT(args)
model = DataParallelModel(model)
model.cuda()

criterion = nn.NLLLoss()
criterion = DataParallelCriterion(criterion) 

...

for i, (inputs, labels) in enumerate(trainloader):
    outputs = model(inputs)          
    loss = criterion(outputs, labels)     
    
    optimizer.zero_grad()
    loss.backward()                        
    optimizer.step()

 

4. (3단계 팀) 여러 컴퓨터를 사용해서 분산학습 하는 경우

    torch.distributed 모듈을 사용한다. DistributedDataParallel과 DistributedSampler를 함께 사용한다.
    => 2019년 기준 학습에 사용하지 않는 파라미터의 존재 여부에 따른 에러가 있어서, Apex를 이용한 방법도 소개되어있다.
    (pip install 하면 다른 apex가 깔림! 공식 github의 설치 방법 참고할 것)
    당근 마켓의 예시에는 torch.distributed와 같이 사용되어 있는데 막상 Apex의 DDP 모듈을 사용할 때는 args를 다 넘겨주지 않아서... args의 변수 몇 개를 설정해주는 이상의 의미가 있는가? 는 잘 모르겠다. 

 

마지막으로 원문의 정리 팁을 보면, 모델 출력이 큰 모델일수록 불균형 문제가 심각할 수 있다고 한다. 원문의 예시는 BERT 였지만 음성 인식/합성 모델의 경우에도 고려해야할 문제로 보인다. 

 

아래는 나중에 다시 들여다보게 되었을 때를 위한 관련 링크.

- pytorch tutorials Distributed Data Parallel
  (영) https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

  (한) https://tutorials.pytorch.kr/intermediate/dist_tuto.html

 

 

여담으로 world는 'group of processes', rank 는 그룹 내 프로세스 고유의 번호로 보인다.
따라서 world_size = 4라면 rank = [0, 1, 2, 3] 이다.

local_rank는 각 프로세스가 속한 GPU의 번호이다. 
nr은 머신(컴퓨터)간 우선순위를 표시하는 데 쓰이는 것으로 보인다. (🤔❔)

참고: https://stackoverflow.com/questions/58271635/in-distributed-computing-what-are-world-size-and-rank

 

이제 쪼금 공부했으니 기존 코드를 뜯어보면서 어떻게 수정할 수 있을지 보면 될 듯...
단일 머신 상황에서는 굳이 dist까지 사용할 필요가 없는 것으로 보이는데, 이와 관련된 공부는 또 다음 포스팅에..

Comments