%reload_ext autoreload
%autoreload 2
#export
from nb_003a import *
from itertools import groupby
# Image folder used throughout this notebook: data/caltech101.
DATA_PATH = Path('data')
PATH = DATA_PATH/'caltech101'
# Seed NumPy's RNG before the datasets are built (presumably so the split is reproducible -
# confirm that ImageDataset.from_folder draws from np.random).
np.random.seed(42)
# Build the training and validation datasets from the folder, passing test_pct=0.2.
train_ds,valid_ds = ImageDataset.from_folder(PATH, test_pct=0.2)
# Keep the first element (the image) of the last training item for the examples below.
x = train_ds[-1][0]
classes = train_ds.classes
# c is the number of classes.
c = len(classes)
len(train_ds),len(valid_ds),c
Since we are dealing with rectangular images of different sizes, we need to standardize their size to be able to train our network. We will start by comparing the dimension (aspect) ratio of each image in our dataset. The dimension ratio is the width of an image divided by its height (this is what the code below computes from PIL's `(width, height)` size). Our final objective is to group images by dimension ratio and then standardize the dimensions of all images in the same group. This is useful because it means that when training we will be able to feed our network a standardized batch of images in each iteration. This does not mean all batches have to have the same dimensions (in fact, we will have a number of distinct dimensions equal to the number of distinct groups), but we do need the images in a single batch to have the same dimensions.
In the following example, we chose to divide our images into 5 groups, so we use 5 percentiles of the ratio distribution as the group representatives: 2, 20, 50, 80, 98.
show_image(train_ds[1][0], figsize=(6,3))
x.shape
# PIL's Image.size is (width, height), so truediv(*size) gives width / height for each file.
asp_ratios = [operator.truediv(*PIL.Image.open(fn).size) for fn in train_ds.fns]
asp_ratios[:4]
# Ratios found at the 2nd, 20th, 50th, 80th and 98th percentiles of the distribution.
asp_ntiles = np.percentile(asp_ratios, [2,20,50,80,98])
asp_ntiles
#export
def closest_ntile(aspect, ntiles):
    "Return the entry of `ntiles` whose logarithm is nearest to the logarithm of `aspect`."
    # Distances are measured in log space, where a ratio r and its inverse 1/r are equally far from 1.
    log_gap = abs(log(aspect) - log(ntiles))
    nearest_idx = np.argmin(log_gap)
    return ntiles[nearest_idx]
# Ratio of x's third dimension to its second; assumes x is laid out as
# (channels, height, width), which would make this width / height - TODO confirm.
aspect = x.shape[2]/x.shape[1]
# Snap that ratio to the closest of the percentile ratios computed above.
nearest_aspect = closest_ntile(aspect, asp_ntiles)
aspect,nearest_aspect
get_crop_target(128, nearest_aspect)
We will now sort our images by aspect ratio and then group them by percentile group. We will build a Batch Sampler that samples images from the same group in a random way.
Notice that our SortAspectBatchSampler returns the image number and its group aspect ratio. When we create our DataLoader, this extra parameter will be fed into our DatasetTfm function. It will enable our crop_pad transform defined in 003a_rect_images to transform each of the images into their group aspect ratio so we finally have only 5 distinct aspect ratios in our dataset.
# Snap every image's ratio to its closest percentile ratio.
asp_nearests = [closest_ntile(o, asp_ntiles) for o in asp_ratios]
asp_nearests[:10]
bs=32
# Pair each image index with its snapped ratio and sort by ratio, so that groupby
# (which only merges adjacent equal keys) sees each ratio as one contiguous run.
sort_nearest = sorted(enumerate(asp_nearests), key=itemgetter(1))
sort_nearest[:5]
# One list per distinct ratio; every entry is (image index, {'aspect': ratio}).
groups = [[(a,{'aspect':b}) for a,b in o] for _,o in groupby(sort_nearest, key=itemgetter(1))]
len(groups)
groups[0][:4]
# Number of batches when each group is split on its own into chunks of at most bs items.
sum(math.ceil(len(g)/bs) for g in groups)
# TODO: use actual AR for shuffle=False, and use mean AR in __iter__
#export
@dataclass
class SortAspectBatchSampler(Sampler):
    """Batch sampler that only puts images with the same representative aspect ratio in a batch.

    On construction every file in `ds.fns` is measured (width / height), the ratios at the
    2nd, 20th, 50th, 80th and 98th percentiles are taken as representatives, and each image
    is assigned to the representative closest to it (see `closest_ntile`).

    Iterating yields lists of at most `bs` entries, each entry being
    `(image index, {'aspect': representative ratio})`. With `shuffle=True` both the order
    inside each group and the order of the batches are randomized.
    """
    ds:Dataset; bs:int; shuffle:bool = False

    @staticmethod
    def _aspect_ratio(fn):
        "Width / height of the image file `fn`; the file is closed as soon as its size is read."
        # PIL opens files lazily and keeps the handle open; the context manager releases it
        # right away instead of leaving one open handle per image until garbage collection.
        with PIL.Image.open(fn) as img: width,height = img.size
        return width/height

    def __post_init__(self):
        asp_ratios = [self._aspect_ratio(fn) for fn in self.ds.fns]
        asp_ntiles = np.percentile(asp_ratios, [2,20,50,80,98])
        asp_nearests = [closest_ntile(o, asp_ntiles) for o in asp_ratios]
        # groupby only merges adjacent equal keys, hence the sort on the same key first.
        sort_nearest = sorted(enumerate(asp_nearests), key=itemgetter(1))
        self.groups = [[(a,{'aspect':b}) for a,b in o]
                       for _,o in groupby(sort_nearest, key=itemgetter(1))]
        # Each group is chunked on its own, so the last batch of a group may be short.
        self.n = sum(math.ceil(len(g)/self.bs) for g in self.groups)

    def __len__(self): return self.n

    def __iter__(self):
        # sample(seq, len(seq)) returns a shuffled copy, leaving self.groups untouched.
        if self.shuffle: groups = [sample(group, len(group)) for group in self.groups]
        else: groups = self.groups
        batches = [group[i:i+self.bs] for group in groups for i in range(0, len(group), self.bs)]
        if self.shuffle: batches = sample(batches, len(batches))
        return iter(batches)
# First batch (batch size 4) with shuffling left at its default of False.
next(iter(SortAspectBatchSampler(train_ds, 4)))
# Same sampler with shuffle=True; draw two batches from one iterator.
it = iter(SortAspectBatchSampler(train_ds, 4, True))
next(it),next(it)
We will first build a dataset wrapper that applies the transforms to our raw images. We will then use a DataBunch class that builds the DataLoaders which load our transformed data in batches. It will be integrated with our SortAspectBatchSampler so that the images in one batch are all transformed to have the same dimensions (and, therefore, the same aspect ratio).
#export
class DatasetTfm(Dataset):
    """Dataset wrapper that runs `apply_tfms` on the `x` of every `(x, y)` item of `ds`.

    `tfms` is the transform collection handed to `apply_tfms`; `kwargs` are keyword
    arguments forwarded to `apply_tfms` on every access. An item may be requested either
    with a plain index or with an `(index, dict)` tuple, in which case the dict supplies
    extra keyword arguments for that item and takes precedence over `kwargs`.
    Attributes that the wrapper does not have are looked up on `ds`.
    """
    def __init__(self, ds: Dataset, tfms: Collection[Callable] = None, **kwargs):
        self.ds,self.tfms,self.kwargs = ds,tfms,kwargs

    def __len__(self): return len(self.ds)

    def __getattr__(self, k):
        # Only called when normal lookup fails. If `ds` itself is missing (instance built
        # without __init__, as pickle and copy do before restoring state), delegating to
        # self.ds would re-enter this method forever, so fail with AttributeError instead.
        if k == 'ds': raise AttributeError(k)
        return getattr(self.ds, k)

    def __getitem__(self,idx):
        if isinstance(idx, tuple): idx,xtra = idx
        else: xtra={}
        x,y = self.ds[idx]
        # Per-item arguments come last in the merge so they override the defaults.
        return apply_tfms(self.tfms, x, **{**self.kwargs, **xtra}), y
def rand_zoom_crop(scale, size):
    "Return a two-element list: `rand_zoom(scale=scale)` followed by `rand_crop(size=size)`."
    zoom_tfm = rand_zoom(scale=scale)
    crop_tfm = rand_crop(size=size)
    return [zoom_tfm, crop_tfm]
def zoom_crop(scale, size):
    "Return a two-element list: `zoom(scale=scale)` followed by `crop(size=size)`."
    zoom_tfm = zoom(scale=scale)
    crop_tfm = crop(size=size)
    return [zoom_tfm, crop_tfm]
# Training pipeline: rotate with degrees=(-20,20.), then the rand_zoom_crop pair at size 150.
train_tfms = [rotate(degrees=(-20,20.)),
*rand_zoom_crop(scale=(1.,2.), size=150)]
# Validation pipeline: a single crop_pad at size 150.
valid_tfms = [crop_pad(size=150)]
# Wrap each dataset so that its transform list is applied whenever an item is read.
train_tds = DatasetTfm(train_ds, train_tfms)
valid_tds = DatasetTfm(valid_ds, valid_tfms)
# Indexing with (index, dict) passes the dict as extra per-item keyword arguments.
xtra = {'size':100}
train_tds[(1,xtra)][0].shape
# First four validation items.
_,axes = plt.subplots(2,2, figsize=(8,6))
for i,ax in enumerate(axes.flat): show_image(valid_tds[i][0], ax)
# Training item 1 read four times; the transforms are re-applied on every read.
_,axes = plt.subplots(2,2, figsize=(8,6))
for ax in axes.flat: show_image(train_tds[1][0], ax, hide_axis=False)
# First four training items, each read with the per-item size argument.
_,axes = plt.subplots(2,2, figsize=(8,6))
for i,ax in enumerate(axes.flat):
im = train_tds[(i, xtra)][0]
print(im.shape)
show_image(im, ax, hide_axis=False)
# Rebuild the wrappers with size and padding_mode stored as keyword arguments for every item.
train_tds = DatasetTfm(train_ds, train_tfms, size=100, padding_mode='zeros')
valid_tds = DatasetTfm(valid_ds, valid_tfms, size=100, padding_mode='zeros')
len(train_tds), len(valid_tds)
# First four training items read with a plain index, so only the stored arguments apply.
_,axes = plt.subplots(2,2, figsize=(8,6))
for i,ax in enumerate(axes.flat):
im = train_tds[i][0]
print(im.shape)
show_image(im, ax, hide_axis=False)
#export
class DataBunch():
    """Hold a training and a validation `DataLoader`, each wrapped in a `DeviceDataLoader`.

    `device` defaults to `default_device`; any extra keyword arguments given to the
    constructor are forwarded to both `DeviceDataLoader` wrappers.
    """
    def __init__(self, train_dl, valid_dl, device=None, **kwargs):
        self.device = default_device if device is None else device
        self.train_dl = DeviceDataLoader(train_dl, device=self.device, **kwargs)
        self.valid_dl = DeviceDataLoader(valid_dl, device=self.device, **kwargs)

    @classmethod
    def create(cls, train_ds, valid_ds, bs=64, device=None, num_workers=4, progress_func=tqdm,
               train_tfm=None, valid_tfm=None, sample_func=None, dl_tfms=None, **kwargs):
        """Build both `DataLoader`s from datasets and return a `DataBunch`.

        - `train_tfm` / `valid_tfm`: when given and the matching dataset is not already a
          `DatasetTfm`, the dataset is wrapped as `DatasetTfm(ds, tfm, **kwargs)`. A dataset
          that is already a `DatasetTfm` is used as-is so its transforms are not applied twice.
        - `bs`: training batch size; validation uses `bs*2`.
        - `sample_func`: optional batch-sampler factory called as `sample_func(ds, bs, shuffle)`.
          Without it the training loader shuffles and the validation loader does not.
        - `dl_tfms`, `progress_func`: forwarded to the constructor as `tfms` and `progress_func`.
        """
        # The wrappers used to be bound to train_tfm/valid_tfm and then never used, so the
        # loaders were built on the unwrapped datasets and the transforms were silently dropped.
        if train_tfm is not None and not isinstance(train_ds, DatasetTfm):
            train_ds = DatasetTfm(train_ds, train_tfm, **kwargs)
        if valid_tfm is not None and not isinstance(valid_ds, DatasetTfm):
            valid_ds = DatasetTfm(valid_ds, valid_tfm, **kwargs)
        if sample_func is None:
            train_dl = DataLoader(train_ds, bs, shuffle=True, num_workers=num_workers)
            valid_dl = DataLoader(valid_ds, bs*2, shuffle=False, num_workers=num_workers)
        else:
            train_samp = sample_func(train_ds, bs, True)
            valid_samp = sample_func(valid_ds, bs*2, False)
            train_dl = DataLoader(train_ds, num_workers=num_workers, batch_sampler=train_samp)
            valid_dl = DataLoader(valid_ds, num_workers=num_workers, batch_sampler=valid_samp)
        return cls(train_dl, valid_dl, device, tfms=dl_tfms, progress_func=progress_func)

    @property
    def train_ds(self): return self.train_dl.dl.dataset

    @property
    def valid_ds(self): return self.valid_dl.dl.dataset
# DataBunch without a sampler; train_tds/valid_tds are already DatasetTfm wrappers.
data = DataBunch.create(train_tds, valid_tds, bs, num_workers=8,
train_tfm=train_tfms, valid_tfm=valid_tfms, size=100, padding_mode='zeros')
# Pull one training batch, print the shape of its first image and show the first eight.
x,y = next(iter(data.train_dl))
print(x[0].shape)
_,axes = plt.subplots(2,4, figsize=(9,3))
for i,ax in enumerate(axes.flat): show_image(x[i], ax)
# Same again, but batches come from SortAspectBatchSampler, which attaches an
# 'aspect' entry to every index it yields.
data = DataBunch.create(train_tds, valid_tds, bs, num_workers=8, sample_func=SortAspectBatchSampler,
train_tfm=train_tfms, valid_tfm=valid_tfms, padding_mode='zeros')
x,y = next(iter(data.train_dl))
print(x[0].shape)
_,axes = plt.subplots(2,4, figsize=(12,4))
for i,ax in enumerate(axes.flat): show_image(x[i], ax, hide_axis=False)