This module defines the basic DataBunch object that is used inside Learner to train a model. This is the generic class that can take any kind of fastai Dataset or DataLoader. You'll find helpful functions in the data module of every application to directly create this DataBunch for you.
from fastai.gen_doc.nbdoc import *
from fastai.basics import *
show_doc(DataBunch)
class DataBunch[source]
DataBunch(train_dl:DataLoader, valid_dl:DataLoader, fix_dl:DataLoader=None, test_dl:Optional[DataLoader]=None, device:device=None, dl_tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.', collate_fn:Callable='data_collate', no_check:bool=False)
Bind train_dl, valid_dl and test_dl in a data object.
It also ensures all the dataloaders are on device and applies dl_tfms to them as batches are drawn (like normalization). path is used internally to store temporary files; collate_fn is passed to the pytorch DataLoader (replacing the one there) to explain how to collate the samples picked for a batch. By default, it grabs the data attribute of the objects sent (see in vision.image or the data block API why this can be important).
train_dl, valid_dl and optionally test_dl will be wrapped in DeviceDataLoader.
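For example, you can bind two pytorch DataLoaders yourself. This is a minimal sketch, where TensorDataset with random tensors stands in for real data:
from torch.utils.data import DataLoader, TensorDataset
train_ds = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
valid_ds = TensorDataset(torch.randn(20, 10), torch.randint(0, 2, (20,)))
data = DataBunch(DataLoader(train_ds, batch_size=16, shuffle=True),
                 DataLoader(valid_ds, batch_size=16))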
show_doc(DataBunch.create)
create[source]
create(train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64, val_bs:int=None, num_workers:int=4, dl_tfms:Optional[Collection[Callable]]=None, device:device=None, collate_fn:Callable='data_collate', no_check:bool=False) → DataBunch
Create a DataBunch from train_ds, valid_ds and maybe test_ds with a batch size of bs.
num_workers is the number of CPU workers to use; dl_tfms, device and collate_fn are passed to the init method.
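For instance, with the two stand-in datasets from the sketch above, create builds the dataloaders for you:
data = DataBunch.create(train_ds, valid_ds, bs=16, num_workers=0)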
jekyll_warn("You can pass regular pytorch Dataset here, but they'll require more attributes than the basic ones to work with the library. See below for more details.")
show_doc(DataBunch.show_batch)
show_batch[source]
show_batch(rows:int=5, ds_type:DatasetType=<DatasetType.Train: 1>, **kwargs)
Show a batch of data in ds_type on a few rows.
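For example, with a DataBunch built by one of the applications or the data block API (a raw pytorch Dataset like the sketch above doesn't know how to display its items; see the notes further down):
data.show_batch(rows=3, ds_type=DatasetType.Valid)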
show_doc(DataBunch.dl)
dl[source]
dl(ds_type:DatasetType=<DatasetType.Valid: 2>) → DeviceDataLoader
Returns the appropriate DeviceDataLoader for training, validation, or test, as selected by ds_type.
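For example:
valid_dl = data.dl(DatasetType.Valid)  # the validation DeviceDataLoader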
show_doc(DataBunch.one_batch)
one_batch[source]
one_batch(ds_type:DatasetType=<DatasetType.Train: 1>, detach:bool=True, denorm:bool=True, cpu:bool=True) → Collection[Tensor]
Get one batch from the data loader of ds_type. Optionally detach and denorm.
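For example, to grab a training batch and inspect its shapes, reusing the data object from the sketch above:
x, y = data.one_batch(ds_type=DatasetType.Train)
x.shape, y.shape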
show_doc(DataBunch.one_item)
one_item[source]
one_item(item, detach:bool=False, denorm:bool=False, cpu:bool=False)
Get item into a batch. Optionally detach and denorm.
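For example, to turn a single element into a batch of size one (this relies on set_item on the underlying dataset, so it needs a DataBunch built through the data block API; see the notes further down):
item = data.train_ds[0][0]   # a single x, e.g. an Image
xb, yb = data.one_item(item)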
show_doc(DataBunch.sanity_check)
sanity_check[source]
sanity_check()
Check that the underlying data in the training set can be properly loaded.
show_doc(DataBunch.export)
export[source]
export(fname:str='export.pkl')
Export the minimal state of self for inference in self.path/fname.
show_doc(DataBunch.load_empty, full_name='load_empty')
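Together with load_empty, this supports an inference-only workflow. A minimal sketch (the state is written to and read from data.path):
data.export('export.pkl')                     # save the minimal state
empty_data = DataBunch.load_empty(data.path)  # recreate an empty DataBunch for inference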
show_doc(DataBunch.add_tfm)
add_tfm[source]
add_tfm(tfm:Callable)
Adds a transform to all dataloaders.
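A tfm receives a whole batch as it is drawn and must return a batch of the same form. A minimal sketch, where _add_noise is a hypothetical transform:
def _add_noise(b):
    "Hypothetical batch tfm: add a little gaussian noise to the inputs."
    x, y = b
    return x + 0.01 * torch.randn_like(x), y
data.add_tfm(_add_noise)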
If you want to use your pytorch Dataset in fastai, you may need to implement more attributes/methods to get the full functionality of the library. Some functions can easily be used with your pytorch Dataset if you just add an attribute; for others, the best option is to create your own ItemList by following this tutorial. Here is a full list of what the library will expect.
First of all, you obviously need to implement the methods __len__ and __getitem__, as indicated by the pytorch docs. Then the most needed things would be:
- c attribute: it's used in most functions that directly create a Learner (tabular_learner, text_classifier_learner, unet_learner, create_cnn) and represents the number of outputs of the final layer of your model (also the number of classes if applicable).
- classes attribute: it's used by ClassificationInterpretation and also in collab_learner (best to use CollabDataBunch.from_df rather than a pytorch Dataset) and represents the unique tags that appear in your data.
- loss_func attribute: it's going to be used by Learner as a default loss function, so if you know your custom Dataset requires a particular loss, you can set it here.
- In text, your dataset will need to have a vocab attribute that should be an instance of Vocab. It's used by text_classifier_learner and language_model_learner when building the model.
In tabular, your dataset will need to have a cont_names attribute (for the names of continuous variables) and a get_emb_szs method that returns a list of tuples (n_classes, emb_sz) representing, for each categorical variable, the number of different codes (don't forget to add 1 for nan) and the corresponding embedding size. Those two are used with the c attribute by tabular_learner.
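As an illustration, here is a sketch of a raw pytorch Dataset carrying those extra attributes (MyDataset, its tensors and the choice of loss function are placeholders, not library code):
class MyDataset(Dataset):
    "Hypothetical pytorch Dataset exposing the attributes fastai expects."
    def __init__(self, x, y, classes):
        self.x, self.y, self.classes = x, y, classes    # classes: the unique tags
        self.c = len(classes)                           # outputs of the final layer
        self.loss_func = nn.CrossEntropyLoss()          # default loss for Learner
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]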
To make the functions below work, you really need to use the data block API and maybe write your own custom ItemList:
- DataBunch.show_batch (requires .x.reconstruct, .y.reconstruct and .x.show_xys)
- Learner.predict (requires x.set_item, .y.analyze_pred, .y.reconstruct and maybe .x.reconstruct)
- Learner.show_results (requires x.reconstruct, y.analyze_pred, y.reconstruct and x.show_xyzs)
- DataBunch.set_item (requires x.set_item)
- Learner.backward (uses DataBunch.set_item)
- DataBunch.export (requires export)
show_doc(DeviceDataLoader)
class DeviceDataLoader[source]
DeviceDataLoader(dl:DataLoader, device:device, tfms:List[Callable]=None, collate_fn:Callable='data_collate')
Bind a DataLoader to a torch.device.
Put the batches of dl on device after applying an optional list of tfms. collate_fn will replace the one of dl. All dataloaders of a DataBunch are of this type.
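For example, to put an existing DataLoader on the default device (a minimal sketch reusing valid_ds from above; defaults.device is the GPU when one is available):
dl = DataLoader(valid_ds, batch_size=16)
device_dl = DeviceDataLoader(dl, device=defaults.device)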
show_doc(DeviceDataLoader.create)
The given collate_fn will be used to put the samples together in one batch (by default it grabs their data attribute). shuffle means the dataloader will pick the samples randomly if that flag is set to True, or in order otherwise. tfms are passed to the init method. All kwargs are passed to the pytorch DataLoader class initialization.
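For example, building directly from a dataset (a minimal sketch reusing valid_ds from above):
device_dl = DeviceDataLoader.create(valid_ds, bs=16, shuffle=False, num_workers=0)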
show_doc(DeviceDataLoader.add_tfm)
show_doc(DeviceDataLoader.remove_tfm)
show_doc(DeviceDataLoader.new)
show_doc(DeviceDataLoader.proc_batch)
show_doc(DatasetType, doc_string=False)
Enum = [Train, Valid, Test, Single, Fix]
Internal enumerator to name the training, validation and test dataset/dataloader.
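These are the values you pass to methods like dl or one_batch. For example:
x, y = data.one_batch(DatasetType.Valid)  # a batch from the validation dataloader
fix_dl = data.dl(DatasetType.Fix)         # the training set without shuffling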
show_doc(DeviceDataLoader.collate_fn)