#default_exp vision.rect_augment
#default_cls_lvl 3

#export
from fastai2.core.imports import *
from fastai2.test import *
from fastai2.core import *
from fastai2.data.transform import *
from fastai2.data.pipeline import *
from fastai2.data.source import *
from fastai2.data.core import *
from fastai2.vision.core import *
from fastai2.vision.augment import *
from fastai2.data.external import *

#hide
from nbdev.showdoc import showdoc

path = untar_data(URLs.PETS)
items = get_image_files(path/'images')
labeller = RegexLabeller(pat = r'/([^/]+)_\d+.jpg$')
split_idx = RandomSplitter()(items)
tfms = [PILImage.create, [labeller, Categorize()]]
tds = TfmdDS(items, tfms)

im = tds[0][0]; im.shape

#export
class SortARSampler(BatchSampler):
    """Batch sampler that groups items by image size, then sorts each group by aspect ratio.

    Items are argsorted by pixel count (largest first) and chunked into groups of `grp_sz`
    (rounded to a multiple of `bs`). Each group is then sorted by h/w, so a batch holds
    images of similar size and similar aspect ratio.
    """
    def __init__(self, ds, items=None, bs=32, grp_sz=1000, shuffle=False, drop_last=False):
        # `ifnone` instead of `not items`: array-like `items` have no unambiguous truth value
        items = ifnone(items, ds.items)
        self.shapes = [self._get_shape(it) for it in items]
        self.sizes = [h*w for h,w in self.shapes]
        self.ars = [h/w for h,w in self.shapes]
        # Group size is a multiple of `bs`, so only the very last batch can be partial.
        # `max(..., bs)` stops a small `grp_sz` from rounding to 0, which would make `range` raise.
        self.ds,self.bs,self.shuffle,self.drop_last = ds,bs,shuffle,drop_last
        self.grp_sz = max(round_multiple(grp_sz,bs), bs)
        # reverse argsort of sizes
        idxs = [i for i,o in sorted(enumerate(self.sizes), key=itemgetter(1), reverse=True)]
        # create approx equal sized groups no larger than `grp_sz`
        grps = [idxs[i:i+self.grp_sz] for i in range(0, len(idxs), self.grp_sz)]
        # sort within groups by aspect ratio
        self.grps = [sorted(g, key=lambda o:self.ars[o]) for g in grps]

    @staticmethod
    def _get_shape(fn):
        "Return `.shape` of the image at `fn`, closing the file handle once it has been read"
        with Image.open(fn) as img: return img.shape

    def __iter__(self):
        grps = self.grps
        if self.shuffle: grps = [shufflish(o) for o in grps]
        # split every group into consecutive batches of `bs` indices
        batches = [g[i:i+self.bs] for g in grps for i in range(0, len(g), self.bs)]
        # groups are multiples of `bs`, so only the final batch can be short
        if self.drop_last and batches and len(batches[-1])!=self.bs: del(batches[-1])
        # Shuffle all but the first batch, and keep that batch (from the largest-image group) first
        if self.shuffle and batches:
            batches = [batches[0]] + random.sample(batches[1:], len(batches)-1)
        return iter(batches)

    def __len__(self): return (len(self.ds) if self.drop_last else (len(self.ds)+self.bs-1)) // self.bs

samp = SortARSampler(tds, shuffle=False)
test_eq(len(samp), (len(tds)-1)//32+1)
itr = iter(samp)
first = next(itr)
i = 1
for last in itr: i += 1
test_eq(len(samp), i)
first = [tds[i][0] for i in first]
last = [tds[i][0] for i in last]
#big images are first, smaller images last
assert np.mean([im.n_px for im in last])*5 < np.mean([im.n_px for im in first])
#Higher aspect ratios are first
assert np.mean([im.aspect for im in last])*2 < np.mean([im.aspect for im in first])
#In a batch with similar aspect ratio
assert np.std([im.aspect for im in first]) < 0.1
assert np.std([im.aspect for im in last]) < 0.1

samp = SortARSampler(tds, shuffle=True)
itr = iter(samp)
first = next(itr)
for last in itr: pass
first = [tds[i][0] for i in first]
last = [tds[i][0] for i in last]
#In a batch with similar aspect ratio
assert np.std([im.aspect for im in first]) < 0.1
assert np.std([im.aspect for im in last]) < 0.1

#export
class ResizeCollate(TfmdCollate):
    """Collate that resizes every sample of a batch to one shared size before collating.

    With `sz=None`, the target area is `max_px` if `is_fixed_px` is set; otherwise it is the
    largest image's pixel count in the batch, capped at `max_px`. The target shape comes from
    the batch's median aspect ratio. Passing `rand_min_scale` or `rand_ratio_pct` switches from
    `Resize` to `RandomResizedCrop`, with the crop ratio centred on that aspect ratio.
    """
    def __init__(self, tfms=None, collate_fn=default_collate, sz=None, is_fixed_px=False, max_px=512*512, round_mult=None,
                 rand_min_scale=None, rand_ratio_pct=None):
        super().__init__(tfms, collate_fn)
        self.round_mult,self.is_fixed_px,self.max_px = round_mult,is_fixed_px,max_px
        # `is not None` so an explicit 0 (e.g. `rand_ratio_pct=0`) still enables random cropping
        self.is_rand = rand_min_scale is not None or rand_ratio_pct is not None
        if self.is_rand:
            self.inv_ratio = 1-ifnone(rand_ratio_pct, 0.10)
            self.resize = RandomResizedCrop(1, min_scale=ifnone(rand_min_scale, 0.25), as_item=False)
        else: self.resize = Resize(1, as_item=False)
        # `sz` is (h,w); an int means a square target
        self.sz = None if sz is None else (sz, sz) if isinstance(sz, int) else sz

    def __call__(self, samples):
        if self.sz is None:
            if self.is_fixed_px: px = self.max_px
            else: px = min(self.max_px, max(L(o[0].shape[0]*o[0].shape[1] for o in samples)))
            # `aspect` is w/h, so a target area `px` gives h = sqrt(px/ar), w = sqrt(px*ar);
            # build (h,w) to match the fixed-`sz` branch below (ar = sz[1]/sz[0] = w/h)
            ar = np.median(L(o[0].aspect for o in samples))
            sz = int(math.sqrt(px/ar)),int(math.sqrt(px*ar))
        else: sz,ar = self.sz,self.sz[1]/self.sz[0]
        if self.round_mult is not None: sz = round_multiple(sz, self.round_mult, round_down=True)
        # random crop ratios (w/h) bracket the batch aspect ratio by +/- `rand_ratio_pct`
        if self.is_rand: self.resize.ratio = (ar*self.inv_ratio, ar/self.inv_ratio)
        return super().__call__(self.resize(o,size=sz) for o in samples)

samp = SortARSampler(tds, shuffle=True, bs=16)
collate_fn = ResizeCollate(max_px=10000)
tdl = TfmdDL(tds, batch_sampler=samp, collate_fn=collate_fn, num_workers=0)
batch = tdl.one_batch()
test_eq(L(batch).map(type), (TensorImage,Tensor))
b,c,h,w = batch[0].shape
assert 9000