# default_exp distributed
#export
from fastai2.basics import *
from fastai2.callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler
Callbacks and helper functions to train in parallel or use distributed training
Patch the parallel models so they work with RNNs
#export
@patch
def reset(self: DataParallel):
    "Patch required `reset` call into `DataParallel`"
    if hasattr(self.module, 'reset'): self.module.reset()
#export
@log_args
class ParallelTrainer(Callback):
    "Wrap a model in `DataParallel` for the duration of `Learner.fit`"
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, device_ids): self.device_ids = device_ids
    def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
    def after_fit(self): self.learn.model = self.learn.model.module
#export
@patch
def to_parallel(self: Learner, device_ids=None):
    "Add `ParallelTrainer` callback to the `Learner`."
    self.add_cb(ParallelTrainer(device_ids))
    return self
#export
@patch
def detach_parallel(self: Learner):
    "Remove ParallelTrainer callback from Learner."
    self.remove_cb(ParallelTrainer)
    return self
#export
@patch
@contextmanager
def parallel_ctx(self: Learner, device_ids=None):
    "A context manager to adapt a learner to train in data parallel mode."
    try:
        self.to_parallel(device_ids)
        yield self
    finally:
        self.detach_parallel()
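For example, here is a minimal sketch of the data parallel context manager in use (the `_toy_learner` helper below is hypothetical, only there to make the snippet self-contained; when no GPU is visible, `DataParallel` simply runs the plain module):
def _toy_learner():
    "Hypothetical helper: a tiny regression Learner on synthetic data."
    x = torch.randn(64,1); y = 2*x + 1
    dls = DataLoaders.from_dsets(list(zip(x,y)), list(zip(x,y)), bs=8, device=default_device())
    return Learner(dls, nn.Linear(1,1), loss_func=MSELossFlat())

learn = _toy_learner()
with learn.parallel_ctx():  # the model is wrapped in DataParallel only inside this block
    learn.fit(1)
# learn.to_parallel()/learn.detach_parallel() do the same thing without a context manager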
Patch the parallel models so they work with RNNs
#export
@patch
def reset(self: DistributedDataParallel):
    "Patch required `reset` call into `DistributedDataParallel`"
    if hasattr(self.module, 'reset'): self.module.reset()
Convenience functions to set up/tear down torch distributed data parallel mode.
#export
def setup_distrib(gpu=None):
    if gpu is None: return gpu
    gpu = int(gpu)
    torch.cuda.set_device(int(gpu))
    if num_distrib() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu
#export
def teardown_distrib():
    if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()
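A quick, hypothetical sanity check (assumes a visible CUDA device and a plain single-process run with no WORLD_SIZE/RANK environment variables set): setup_distrib only pins the CUDA device and creates no process group, and teardown_distrib is then a no-op.
gpu = setup_distrib(0)  # single process: pins cuda:0; num_distrib() <= 1 so no group is created
assert gpu == 0 and not torch.distributed.is_initialized()
teardown_distrib()      # nothing to destroy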
We need to change the dataloaders so that each process only gets one part of the batch (otherwise there is no point in using distributed training).
#export
@log_args(but_as=TfmdDL.__init__)
@delegates()
class DistributedDL(TfmdDL):
    "A `TfmdDL` that only goes through its `rank`'s portion of the (padded) indices out of `world_size` processes"
    def __init__(self, dataset, rank, world_size, **kwargs):
        super().__init__(dataset, **kwargs)
        if self.n%world_size != 0: self.n += world_size-self.n%world_size
        self.total_n,self.n = self.n,self.n//world_size
        store_attr(self, 'rank,world_size')
    def get_idxs(self):
        idxs = Inf.count if self.indexed else Inf.nones
        return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))
    def shuffle_fn(self, idxs):
        "Deterministically shuffle on each training process based on epoch."
        g = torch.Generator()
        g.manual_seed(self.epoch)
        return L(idxs)[torch.randperm(self.total_n, generator=g)]
    def sample(self):
        idxs = self.get_idxs()
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        # add extra samples if needed to reach `total_n` (already a multiple of `world_size`)
        idxs += idxs[:(self.total_n - len(idxs))]
        # subsample: this rank takes every `world_size`-th index, starting at `rank`
        idxs = idxs[self.rank:self.total_n:self.world_size]
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
    def create_item(self, s):
        if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
        return s if hasattr(self.dataset, 'iloc') else super().create_item(s)
    def set_epoch(self, epoch): self.epoch = epoch
    @classmethod
    def from_dl(cls, dl, rank, world_size, **kwargs):
        cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
                          bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed, device=dl.device)
        cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "get_idxs sample shuffle_fn create_item".split()})
        return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))
dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    dl1.set_epoch(0)
    res += list(dl1)[0].tolist()
#All items should only be accessed once (except 0 and 1, which are duplicated to pad the final cycle) with a seeded shuffle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))
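Another quick check, not from the original tests: for a fixed epoch the seeded shuffle is deterministic (re-iterating the same DistributedDL gives the same batch), while calling set_epoch with a new value reshuffles.
dl = TfmdDL(list(range(50)), bs=16, shuffle=True)
dl0 = DistributedDL.from_dl(dl, 0, 4)
dl0.set_epoch(0)
b0 = list(dl0)[0].tolist()
dl0.set_epoch(0)
test_eq(list(dl0)[0].tolist(), b0)  # same epoch -> identical order
dl0.set_epoch(1)
test_ne(list(dl0)[0].tolist(), b0)  # new epoch -> new shuffle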
#export
@log_args
class DistributedTrainer(Callback):
    "Wrap `model` in `DistributedDataParallel` and the `dls` in `DistributedDL`"
    run_after,run_before = TrainEvalCallback,Recorder
    fup = None  # for `find_unused_parameters` in DistributedDataParallel()
    def __init__(self, cuda_id=0): self.cuda_id = cuda_id
    def before_fit(self):
        opt_kwargs = {'find_unused_parameters': DistributedTrainer.fup} if DistributedTrainer.fup is not None else {}
        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id, **opt_kwargs)
        self.old_dls = list(self.dls)
        self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]
        if rank_distrib() > 0: self.learn.logger = noop
    def _wrap_dl(self, dl):
        return dl if isinstance(dl, DistributedDL) else DistributedDL.from_dl(dl, rank_distrib(), num_distrib())
    def before_epoch(self):
        for dl in self.dls: dl.set_epoch(self.epoch)
    def before_train(self): self.learn.dl = self._wrap_dl(self.learn.dl)
    def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl)
    def after_fit(self):
        self.learn.model = self.learn.model.module
        self.learn.dls.loaders = self.old_dls
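If the model has parameters that do not receive gradients on every step, DistributedDataParallel needs find_unused_parameters=True; the class attribute fup above exists for exactly that. A hedged sketch of how one might enable it before training:
# Set once, before calling `fit`: `before_fit` will then pass find_unused_parameters=True
# to DistributedDataParallel (purely illustrative).
DistributedTrainer.fup = True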
Attach or remove a DistributedTrainer callback, which adapts the model (via DistributedDataParallel) and the dataloaders (via DistributedDL) to train in distributed data parallel mode.
#export
@patch
def to_distributed(self: Learner, cuda_id):
    "Add `DistributedTrainer` callback to the `Learner` (and silence the progress bar on non-zero ranks)."
    self.add_cb(DistributedTrainer(cuda_id))
    if rank_distrib() > 0: self.remove_cb(ProgressCallback)
    return self
#export
@patch
def detach_distributed(self: Learner):
    "Remove `DistributedTrainer` callback from the `Learner` (and restore the progress bar if needed)."
    if num_distrib() <= 1: return self
    self.remove_cb(DistributedTrainer)
    if rank_distrib() > 0 and not hasattr(self, 'progress'): self.add_cb(ProgressCallback())
    return self
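A hedged sketch of the explicit attach/detach pattern inside a script launched with python -m fastai2.launch; make_learner is a hypothetical function that builds the Learner on each process:
def train(make_learner):
    "Hypothetical per-process driver for a launched training script."
    gpu = setup_distrib(rank_distrib())  # pin this process to its GPU, init the process group
    learn = make_learner()               # build dls/model/Learner on this process
    learn.to_distributed(gpu)            # DDP model + DistributedDL dls during fit
    learn.fit(1)
    learn.detach_distributed()           # back to a plain Learner
    teardown_distrib()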
#export
@patch
@contextmanager
def distrib_ctx(self: Learner, cuda_id=None):
    "A context manager to adapt a learner to train in distributed data parallel mode."
    # Figure out the GPU to use from rank. Create a distributed process group if none exists yet.
    if cuda_id is None: cuda_id = rank_distrib()
    if not torch.distributed.is_initialized():
        setup_distrib(cuda_id)
        cleanup_dpg = torch.distributed.is_initialized()
    else: cleanup_dpg = False
    # Adapt self to DistributedDataParallel, yield, and clean up afterwards.
    try:
        if num_distrib() > 1: self.to_distributed(cuda_id)
        yield self
    finally:
        self.detach_distributed()
        if cleanup_dpg: teardown_distrib()
distrib_ctx context manager
distrib_ctx(cuda_id) prepares a learner to train in distributed data parallel mode. It assumes the distributed environment variables have all been set up properly, for example by a script launched with python -m fastai2.launch.
with learn.distrib_ctx(): learn.fit(.....)
It attaches a DistributedTrainer callback to the learner, which wraps the model in DistributedDataParallel and the data loaders in DistributedDL, then executes learn.fit(.....). Upon exiting the context, it removes the DistributedTrainer callback and DistributedDL wrappers, and destroys any locally created distributed process group. The process is still attached to the GPU though.
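The same hypothetical make_learner pattern as above, now with the context manager; since to_distributed is only applied when num_distrib() > 1, the identical script also runs unchanged as a plain single process:
def train(make_learner):
    "Hypothetical per-process driver: all the distributed plumbing is scoped to the `with` block."
    learn = make_learner()
    with learn.distrib_ctx():  # picks the GPU from the rank; creates/destroys the process group as needed
        learn.fit(1)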
#export
def rank0_first(func):
    "Execute `func` in the Rank-0 process first, then in other ranks in parallel."
    dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)
    with dummy_l.distrib_ctx():
        if rank_distrib() == 0: res = func()
        distrib_barrier()
        if rank_distrib() != 0: res = func()
    return res
rank0_first(f) calls f() in the rank-0 process first, then in parallel on the rest, in distributed training mode. In single-process, non-distributed training mode, f() is called only once, as expected.
One application of rank0_first() is to make fresh downloads via untar_data() safe in distributed training scripts launched by python -m fastai2.launch <script>:
path = untar_data(URLs.IMDB)
becomes:
path = rank0_first(lambda: untar_data(URLs.IMDB))
Some learner factory methods may use untar_data() to download pretrained models by default:
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
becomes:
learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))
Otherwise, multiple processes will download at the same time and corrupt the data.
#hide
from nbdev.export import notebook2script
notebook2script()