# default_exp distributed

#export
from fastai2.basics import *
from fastai2.callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler

#export
@patch
def reset(self: DataParallel):
    if hasattr(self.module, 'reset'): self.module.reset()

#export
@log_args
class ParallelTrainer(Callback):
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, device_ids): self.device_ids = device_ids
    def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
    def after_fit(self): self.learn.model = self.learn.model.module

#export
@patch
def to_parallel(self: Learner, device_ids=None):
    self.add_cb(ParallelTrainer(device_ids))
    return self

#export
@patch
def detach_parallel(self: Learner):
    "Remove ParallelTrainer callback from Learner."
    self.remove_cb(ParallelTrainer)
    return self

#export
@patch
@contextmanager
def parallel_ctx(self: Learner, device_ids=None):
    "A context manager to adapt a learner to train in data parallel mode."
    try:
        self.to_parallel(device_ids)
        yield self
    finally: self.detach_parallel()
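# Usage sketch (illustration only, not exported): `parallel_ctx` adds a `ParallelTrainer` callback,
# so `learn.model` is wrapped in `nn.DataParallel` for the duration of each fit and unwrapped in
# `after_fit`; the callback itself is removed when the context exits. The learner and the device
# ids below are hypothetical placeholders for a multi-GPU machine.
#
#   with learn.parallel_ctx(device_ids=[0, 1]):
#       learn.fit_one_cycle(1)
#   assert not isinstance(learn.model, DataParallel)  # back to the bare module after training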
#export
@patch
def reset(self: DistributedDataParallel):
    if hasattr(self.module, 'reset'): self.module.reset()

#export
def setup_distrib(gpu=None):
    if gpu is None: return gpu
    gpu = int(gpu)
    torch.cuda.set_device(int(gpu))
    if num_distrib() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu

#export
def teardown_distrib():
    if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()

#export
@log_args(but_as=TfmdDL.__init__)
@delegates()
class DistributedDL(TfmdDL):
    def __init__(self, dataset, rank, world_size, **kwargs):
        super().__init__(dataset, **kwargs)
        if self.n%world_size != 0: self.n += world_size-self.n%world_size
        self.total_n,self.n = self.n,self.n//world_size
        store_attr(self, 'rank,world_size')
    def get_idxs(self):
        idxs = Inf.count if self.indexed else Inf.nones
        return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))
    def shuffle_fn(self, idxs):
        "Deterministically shuffle on each training process based on epoch."
        g = torch.Generator()
        g.manual_seed(self.epoch)
        return L(idxs)[torch.randperm(self.total_n, generator=g)]
    def sample(self):
        idxs = self.get_idxs()
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        # add extra samples to make it evenly divisible
        idxs += idxs[:(self.total_n - len(idxs))]
        # subsample
        idxs = idxs[self.rank:self.total_n:self.world_size]
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
    def create_item(self, s):
        if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
        return s if hasattr(self.dataset, 'iloc') else super().create_item(s)
    def set_epoch(self, epoch): self.epoch = epoch
    @classmethod
    def from_dl(cls, dl, rank, world_size, **kwargs):
        cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
                          bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed, device=dl.device)
        cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "get_idxs sample shuffle_fn create_item".split()})
        return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))

dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)

dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    dl1.set_epoch(0)
    res += list(dl1)[0].tolist()
#All items should be accessed only once (except 0 and 1, which pad the final cycle) with the seeded shuffle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))
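# Extra illustration (not part of the original tests): `shuffle_fn` seeds its generator with
# `self.epoch`, so for a fixed epoch every rank draws the same permutation and re-iterating the
# same DistributedDL reproduces the same batches; calling `set_epoch` with a new value reseeds
# the shuffle for the next epoch.
dl = TfmdDL(list(range(50)), bs=16, shuffle=True)
dl0 = DistributedDL.from_dl(dl, 0, 4)
dl0.set_epoch(0)
first_batch = list(dl0)[0]
dl0.set_epoch(0)
test_eq(list(dl0)[0], first_batch)  # same epoch -> identical order on this rank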
#export
@log_args
class DistributedTrainer(Callback):
    run_after,run_before = TrainEvalCallback,Recorder
    fup = None  # for `find_unused_parameters` in DistributedDataParallel()
    def __init__(self, cuda_id=0): self.cuda_id = cuda_id
    def before_fit(self):
        opt_kwargs = {'find_unused_parameters': DistributedTrainer.fup} if DistributedTrainer.fup is not None else {}
        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id, **opt_kwargs)
        self.old_dls = list(self.dls)
        self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]
        if rank_distrib() > 0: self.learn.logger=noop
    def _wrap_dl(self, dl):
        return dl if isinstance(dl, DistributedDL) else DistributedDL.from_dl(dl, rank_distrib(), num_distrib())
    def before_epoch(self):
        for dl in self.dls: dl.set_epoch(self.epoch)
    def before_train(self):    self.learn.dl = self._wrap_dl(self.learn.dl)
    def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl)
    def after_fit(self):
        self.learn.model = self.learn.model.module
        self.learn.dls.loaders = self.old_dls

#export
@patch
def to_distributed(self: Learner, cuda_id):
    self.add_cb(DistributedTrainer(cuda_id))
    if rank_distrib() > 0: self.remove_cb(ProgressCallback)
    return self

#export
@patch
def detach_distributed(self: Learner):
    if num_distrib() <= 1: return self
    self.remove_cb(DistributedTrainer)
    if rank_distrib() > 0 and not hasattr(self, 'progress'): self.add_cb(ProgressCallback())
    return self

#export
@patch
@contextmanager
def distrib_ctx(self: Learner, cuda_id=None):
    "A context manager to adapt a learner to train in distributed data parallel mode."
    # Figure out the GPU to use from rank. Create a dpg if none exists yet.
    if cuda_id is None: cuda_id = rank_distrib()
    if not torch.distributed.is_initialized():
        setup_distrib(cuda_id)
        cleanup_dpg = torch.distributed.is_initialized()
    else: cleanup_dpg = False
    # Adapt self to DistributedDataParallel, yield, and cleanup afterwards.
    try:
        if num_distrib() > 1: self.to_distributed(cuda_id)
        yield self
    finally:
        self.detach_distributed()
        if cleanup_dpg: teardown_distrib()

#export
def rank0_first(func):
    "Execute `func` in the Rank-0 process first, then in other ranks in parallel."
    dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)
    with dummy_l.distrib_ctx():
        if rank_distrib() == 0: res = func()
        distrib_barrier()
        if rank_distrib() != 0: res = func()
    return res

#hide
from nbdev.export import notebook2script
notebook2script()
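# Usage sketches (illustration only, not exported; they assume a script started with one process
# per GPU, e.g. `python -m torch.distributed.launch --nproc_per_node=2 train.py`, where `train.py`
# is a hypothetical training script).
#
# `distrib_ctx` picks the GPU from the process rank when `cuda_id` is None, sets up the process
# group if needed, trains with DistributedTrainer when more than one process is running, and
# detaches/cleans up afterwards:
#
#   with learn.distrib_ctx():
#       learn.fit_one_cycle(1)
#
# `rank0_first` is useful for one-time work such as dataset downloads: rank 0 runs `func` first,
# every rank then meets at `distrib_barrier`, and the remaining ranks run `func` afterwards, when
# the result is already cached on disk:
#
#   path = rank0_first(lambda: untar_data(URLs.IMDB))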