Specifying data module¶
Bases: ABC, LightningDataModule, LoggerMixin
The class with the logic for dataset management. The class provides a user with a simple interface:
prepare_data
is a method for downloading and preprocessing datasets
prepare_traindatasets
is a method returning a torch.Dataset for train data
prepare_valdatasets
is a method returning a torch.Dataset for validation data
prepare_trainvaldatasets
is a method returning a tuple of two torch.Datasets for train and validation data (if this method is provided, prepare_traindatasets and prepare_valdatasets shouldn't be implemented)
prepare_testdataset
is a method returning a torch.Dataset for test data
prepare_predictdataset
is a method returning a torch.Dataset for prediction data
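The shape of this interface can be sketched without Kit4DL or PyTorch installed. The sketch below is illustrative only: FakeDataset stands in for torch.utils.data.Dataset, the sample data is made up, and in real code the class would derive from Kit4DLAbstractDataModule.

```python
# Illustrative sketch: stand-ins replace the torch / Kit4DL imports so the
# shape of the datamodule interface is visible without dependencies.

class FakeDataset:
    """Stand-in for torch.utils.data.Dataset: indexable, with a length."""
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


class MyDatamodule:  # in real code: class MyDatamodule(Kit4DLAbstractDataModule)
    def prepare_data(self):
        # download / preprocess once, before any split is created
        self.raw = list(range(100))

    def prepare_trainvaldatasets(self, val_fraction: float = 0.2):
        # a generator yielding (train, validation) dataset tuples; since this
        # method is provided, prepare_traindatasets / prepare_valdatasets
        # are not implemented
        cut = int(len(self.raw) * (1 - val_fraction))
        yield FakeDataset(self.raw[:cut]), FakeDataset(self.raw[cut:])

    def prepare_testdataset(self):
        return FakeDataset(self.raw[-10:])


dm = MyDatamodule()
dm.prepare_data()
train, val = next(dm.prepare_trainvaldatasets())
print(len(train), len(val))  # 80 20
```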
Source code in kit4dl/nn/dataset.py
prepare_data ¶
prepare_data()
Prepare dataset for train/validation/test/predict splits.
Examples¶
class MyDatamodule(Kit4DLAbstractDataModule):
    def prepare_data(self):
        # any logic you need to perform before creating splits
        download_dataset()
Source code in kit4dl/nn/dataset.py
prepare_traindataset ¶
Prepare dataset for training.
Parameters¶
*args : Any
    Positional arguments to set up the dataset
**kwargs : Any
    Named arguments required to set up the dataset
Returns¶
train_dataset : Dataset
    A training dataset
Examples¶
...
def prepare_traindatasets(self, root_dir: str) -> Dataset:
    train_dset = MyDataset(root_dir=root_dir)
    return train_dset
Source code in kit4dl/nn/dataset.py
prepare_valdataset ¶
Prepare dataset for validation.
Parameters¶
*args : Any
    Positional arguments to set up the dataset
**kwargs : Any
    Named arguments required to set up the dataset
Returns¶
val_dataset : Dataset
    A validation dataset
Examples¶
...
def prepare_valdatasets(self, root_dir: str) -> Dataset:
    val_dset = MyDataset(root_dir=root_dir)
    return val_dset
Source code in kit4dl/nn/dataset.py
prepare_trainvaldatasets ¶
prepare_trainvaldatasets(*args: Any, **kwargs: Any) -> Generator[tuple[Dataset, Dataset], None, None]
Prepare dataset for training and validation.
Parameters¶
*args : Any
    Positional arguments to set up the dataset
**kwargs : Any
    Named arguments required to set up the dataset
Returns¶
trainval_dataset_generators : Generator of tuples of Datasets
    A generator yielding tuples of train and validation datasets for each split
Examples¶
...
def prepare_trainvaldatasets(self, root_dir: str):
    dset = MyDataset(root_dir=root_dir)
    train_dset, val_dset = random_split(dset, [500, 50])
    yield train_dset, val_dset
Source code in kit4dl/nn/dataset.py
prepare_testdataset ¶
Prepare dataset for testing.
Parameters¶
*args : Any
    Positional arguments to set up the dataset
**kwargs : Any
    Named arguments required to set up the dataset
Returns¶
test_dataset : Dataset
    A test dataset
Examples¶
...
def prepare_testdataset(self, root_dir: str) -> Dataset:
    return MyDataset(root_dir=root_dir)
Source code in kit4dl/nn/dataset.py
prepare_predictdataset ¶
Prepare dataset for predicting.
Parameters¶
*args : Any
    Positional arguments to set up the dataset
**kwargs : Any
    Named arguments required to set up the dataset
Returns¶
pred_dataset : Dataset
    A prediction dataset
Examples¶
...
def prepare_predictdataset(self, root_dir: str) -> Dataset:
    return MyDataset(root_dir=root_dir)
Source code in kit4dl/nn/dataset.py
setup ¶
setup(stage: str) -> None
Set up data depending on the stage.
The method should not be overridden unless necessary.
Parameters¶
stage : str
The stage of the pipeline. One out of ['fit', 'test', 'predict']
Source code in kit4dl/nn/dataset.py
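For intuition, setup follows the usual LightningDataModule contract: each stage determines which of the prepare_* hooks above are exercised. The mapping below is an illustrative reconstruction of that contract, not the actual implementation in kit4dl/nn/dataset.py.

```python
# Illustrative only: the stage -> preparation-hook mapping implied by the
# interface; the real setup() logic lives in kit4dl/nn/dataset.py.
STAGE_HOOKS = {
    # 'fit' may alternatively use prepare_traindatasets + prepare_valdatasets
    "fit": ["prepare_trainvaldatasets"],
    "test": ["prepare_testdataset"],
    "predict": ["prepare_predictdataset"],
}


def hooks_for(stage: str) -> list[str]:
    """Return the preparation hooks relevant for a pipeline stage."""
    if stage not in STAGE_HOOKS:
        raise ValueError(f"unknown stage: {stage!r}")
    return STAGE_HOOKS[stage]


print(hooks_for("test"))  # ['prepare_testdataset']
```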
get_collate_fn ¶
get_collate_fn() -> Callable | None
Get batch collate function.
Source code in kit4dl/nn/dataset.py
get_train_collate_fn ¶
get_train_collate_fn() -> Callable | None
Get train specific collate function.
Source code in kit4dl/nn/dataset.py
get_val_collate_fn ¶
get_val_collate_fn() -> Callable | None
Get validation specific collate function.
Source code in kit4dl/nn/dataset.py
get_test_collate_fn ¶
get_test_collate_fn() -> Callable | None
Get test specific collate function.
Source code in kit4dl/nn/dataset.py
get_predict_collate_fn ¶
get_predict_collate_fn() -> Callable | None
Get predict specific collate function.
Source code in kit4dl/nn/dataset.py
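Each getter above returns a Callable | None; when not None, the callable is used as the DataLoader's collate_fn for that split. A minimal, dependency-free sketch of a collate function one might return is shown below; the helper name pad_collate is made up for illustration, and a real implementation would stack the results into tensors.

```python
def pad_collate(batch, pad_value=0):
    """Collate variable-length sequences into equal-length lists.

    Stand-in for a torch collate_fn; real code would return padded tensors.
    """
    max_len = max(len(seq) for seq in batch)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in batch]


# In a datamodule, one would return it from the relevant getter, e.g.:
#     def get_train_collate_fn(self):
#         return pad_collate

batch = [[1, 2, 3], [4], [5, 6]]
print(pad_collate(batch))  # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
```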
trainval_dataloaders ¶
Prepare loaders for train and validation data.
Source code in kit4dl/nn/dataset.py
test_dataloader ¶
test_dataloader() -> DataLoader
Prepare loader for test data.
Source code in kit4dl/nn/dataset.py
predict_dataloader ¶
predict_dataloader() -> DataLoader
Prepare loader for prediction data.
Source code in kit4dl/nn/dataset.py
numpy_train_dataloader ¶
numpy_train_dataloader()
Prepare loader for train data for models accepting numpy.ndarray.
Source code in kit4dl/nn/dataset.py
numpy_val_dataloader ¶
numpy_val_dataloader()
Prepare loader for val data for models accepting numpy.ndarray.
Source code in kit4dl/nn/dataset.py
numpy_testdataloader ¶
numpy_testdataloader()
Prepare loader for test data for models accepting numpy.ndarray.
Source code in kit4dl/nn/dataset.py
numpy_predictdataloader ¶
numpy_predictdataloader()
Prepare loader for pred data for models accepting numpy.ndarray.
Source code in kit4dl/nn/dataset.py
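The numpy_* loaders above serve models that consume numpy.ndarray batches (for example scikit-learn style estimators) rather than torch tensors. The conversion idea can be sketched independently of Kit4DL; the wrapper name as_numpy_batches and the toy batch data are illustrative assumptions, with np.asarray as the assumed conversion.

```python
import numpy as np


def as_numpy_batches(dataloader):
    """Illustrative wrapper: yield each (inputs, targets) batch as numpy arrays."""
    for x, y in dataloader:
        yield np.asarray(x), np.asarray(y)


# toy "dataloader": any iterable of (inputs, targets) batches works here
batches = [([[1.0, 2.0], [3.0, 4.0]], [0, 1])]
x, y = next(iter(as_numpy_batches(batches)))
print(x.shape, y.shape)  # (2, 2) (2,)
```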
Custom splits¶
Note
Available since 2024.5b0
In Kit4DL, you can easily define the logic for cross-validation. Starting from version 2024.5b0, the old method prepare_trainvaldataset
was replaced by the prepare_trainvaldatasets
method, which is a generator. You define the logic of the generator yourself. To run 10-fold cross-validation, implement the method in the following way:
...
from sklearn.model_selection import KFold

class MNISTCustomDatamodule(Kit4DLAbstractDataModule):
    def prepare_trainvaldatasets(self, root_dir: str):
        dset = MNIST(
            root=root_dir,
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        split = KFold(n_splits=10, shuffle=True, random_state=0)
        for i, (train_ind, val_ind) in enumerate(
            split.split(dset.data, dset.targets)
        ):
            yield Subset(dset, train_ind), Subset(dset, val_ind)
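Under the hood, KFold only partitions sample indices; each yielded pair of index arrays selects the corresponding Subsets. The same 10-fold index logic can be sketched in plain Python (illustrative only, without the shuffling that KFold applies above):

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train_indices, val_indices) for each of n_splits contiguous folds."""
    # distribute any remainder over the first few folds, as KFold does
    fold_sizes = [
        n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        for i in range(n_splits)
    ]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, val
        start += size


folds = list(kfold_indices(100, 10))
print(len(folds), len(folds[0][0]), len(folds[0][1]))  # 10 90 10
```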
If you want to stick to the old logic and return a single split, just yield
the corresponding datasets:
...
class MNISTCustomDatamodule(Kit4DLAbstractDataModule):
    def prepare_trainvaldatasets(self, root_dir: str):
        tr_dset = MNIST(
            root=root_dir,
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        ts_dset = MNIST(
            root=root_dir,
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )
        yield tr_dset, ts_dset
Each generated tuple of train and validation datasets will be fed into the training/validation pipeline. If you use external metric loggers, results for each split will be uploaded using the experiment name and a suffix like (split=0).
The suffix can be overwritten with the environment variable KIT4DL_SPLIT_PREFIX.
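For example, the variable could be set before the run starts; the value "fold" below is an arbitrary illustration.

```python
import os

# Override the split suffix used by external metric loggers
os.environ["KIT4DL_SPLIT_PREFIX"] = "fold"
print(os.environ["KIT4DL_SPLIT_PREFIX"])  # fold
```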