webdataset Streaming

This is an example notebook showing how to use the webdataset format in streaming mode.
We'll use the VQGAN Pairs dataset. It is a custom collection of about 2M image pairs that I prepared for super-resolution and image enhancement tests.
Each entry contains just two images:
input image
decoded image

Webdataset uses the ancient tar file format, originally designed for tape archival purposes. A tar is just a collection of files with no other metadata. Webdataset creates virtual records by looking at the file names. For example, given the following 4 files:
11329763_91928ff2b7_o.decoded.jpg
11329763_91928ff2b7_o.input.jpg
11349915304_05d4459fae_o.decoded.jpg
11349915304_05d4459fae_o.input.jpg
webdataset would match the filenames on their common prefix (everything before the first dot) and expose them as two records, each with entries called decoded.jpg and input.jpg. If this dataset had captions stored in files with extension .txt, each record would gain another entry named txt.
Note: the tar format is sequential. When preparing the dataset, make sure that files belonging to the same record are stored one after the other.
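The record-matching rule above can be sketched in a few lines of plain Python. This is a simplified illustration of the grouping, not webdataset's actual code, and the function name is made up:

```python
import os

def group_by_key(filenames):
    # Simplified sketch of webdataset's grouping: the key is everything
    # before the first dot of the basename; the rest of the name
    # (e.g. "decoded.jpg") becomes the entry name within the record.
    records = {}
    for name in filenames:
        base = os.path.basename(name)
        key, _, entry = base.partition(".")
        records.setdefault(key, {})[entry] = name
    return records

files = [
    "11329763_91928ff2b7_o.decoded.jpg",
    "11329763_91928ff2b7_o.input.jpg",
    "11349915304_05d4459fae_o.decoded.jpg",
    "11349915304_05d4459fae_o.input.jpg",
]
print(group_by_key(files))
```

The four files collapse into two records, each holding a decoded.jpg and an input.jpg entry.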
We specify the names of the files we want to download using brace expansion.
import braceexpand
import webdataset as wds
files = "{00001..00954}.tar"
urls = f"https://huggingface.co/datasets/dalle-mini/vqgan-pairs/resolve/main/data/{files}"
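As a sketch of what this pattern expands to (webdataset performs the expansion internally via the braceexpand package; here we reproduce it with stdlib formatting only):

```python
files = "{00001..00954}.tar"

# Equivalent expansion with plain string formatting: zero-padded
# shard names 00001.tar through 00954.tar.
expanded = [f"{i:05d}.tar" for i in range(1, 955)]
print(expanded[0], expanded[-1], len(expanded))  # -> 00001.tar 00954.tar 954
```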
Let's take an initial look at the API.
pil_dataset = (
    wds.WebDataset(urls, handler=wds.warn_and_continue)  # handler is optional
    .decode("pil")                         # Decode potential images as PIL
    .to_tuple("decoded.jpg", "input.jpg")  # Place _only_ these entries in a tuple
    .batched(4)                            # Return batches of 4 items
)
%%time
decoded, inputs = next(iter(pil_dataset))
CPU times: user 78.9 ms, sys: 21.4 ms, total: 100 ms Wall time: 1.45 s
decoded[0]
(The decoded quality is quite poor because we used a VQGAN with a very aggressive f16 compression factor.)
inputs[0]
This is a more realistic use case: we'll produce batches of tensors.
You can use any normalization or preparation tasks you need for the model you'll use.
import torch
import torchvision.transforms as T
bs = 32
preprocess_image = T.Compose([
    T.ToTensor(),                    # PIL image -> CHW float tensor in [0, 1]
    lambda t: t.permute(1, 2, 0),    # Reorder to HWC; adapt to whatever your model needs
])
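To make the layout change concrete, here is the same reordering sketched with numpy: ToTensor produces a channels-first (CHW) float tensor scaled to [0, 1], and the permute moves it back to channels-last (HWC).

```python
import numpy as np

# A fake 4x4 RGB "image" in HWC uint8 layout, as PIL would hand it over.
img = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)

# What ToTensor does: scale to [0, 1] floats and move channels first (HWC -> CHW).
chw = img.transpose(2, 0, 1).astype(np.float32) / 255.0

# What the lambda does: move channels last again (CHW -> HWC).
hwc = chw.transpose(1, 2, 0)

print(chw.shape, hwc.shape)  # -> (3, 4, 4) (4, 4, 3)
```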
def preprocess(sample):
    # We are changing the keys too (decoded.jpg -> decoded)
    return {
        "decoded": preprocess_image(sample["decoded.jpg"]),
        "input": preprocess_image(sample["input.jpg"]),
    }
dataset = (
    wds.WebDataset(urls, handler=wds.warn_and_continue)
    .shuffle(2500)
    .decode("pil")
    .map(preprocess)
    .to_tuple("decoded", "input")
    .batched(bs)
)
decoded, inputs = next(iter(dataset))
decoded.shape, inputs.shape
(torch.Size([32, 256, 256, 3]), torch.Size([32, 512, 512, 3]))
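The .shuffle(2500) step above is a streaming shuffle: it fills a buffer of 2500 samples and yields random elements from it, trading memory for shuffle quality. Roughly like this simplified sketch (not webdataset's actual implementation):

```python
import random

def buffered_shuffle(iterable, bufsize, rng=random):
    # Simplified sketch of a streaming shuffle: keep up to `bufsize`
    # items in a buffer and yield a random one whenever it is full.
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= bufsize:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]
            yield buf.pop()
    rng.shuffle(buf)  # drain the remainder in random order
    yield from buf

out = list(buffered_shuffle(range(10), 4, random.Random(0)))
print(sorted(out) == list(range(10)))  # every sample appears exactly once
```

A larger buffer gives a better approximation of a full shuffle but uses more memory per worker.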
It is usually recommended to avoid batching in the DataLoader; instead, batch the dataset and then rebatch (if necessary) after the loader.
num_workers = 8
dl = wds.WebLoader(
    dataset,
    batch_size=None,
    num_workers=num_workers,
)
dl = dl.unbatched().batched(bs) # batch size could be different here
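The unbatched().batched(bs) pattern flattens the per-worker batches into a stream of samples and regroups them. The equivalent logic over plain iterables, as a sketch:

```python
import itertools

def unbatch(batches):
    # Flatten a stream of batches into a stream of samples.
    for batch in batches:
        yield from batch

def rebatch(samples, bs):
    # Regroup a stream of samples into batches of size bs.
    it = iter(samples)
    while chunk := list(itertools.islice(it, bs)):
        yield chunk

worker_batches = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(list(rebatch(unbatch(worker_batches), 4)))  # -> [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```

This is why rebatching after the loader works even when workers produce uneven or partial batches.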
We are now ready to use this dataloader in a training loop.
import itertools
from tqdm import tqdm
def do_loop(dl):
    max_samples = 3200
    for decoded, inputs in tqdm(itertools.islice(dl, max_samples // bs), total=max_samples // bs):
        # do something with these images
        pass
do_loop(dl)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:19<00:00, 5.16it/s]