
Specifying data module

Bases: ABC, LightningDataModule, LoggerMixin

The class with the logic for dataset management.

The class provides the user with a simple interface:

- `prepare_data`: downloads and preprocesses datasets
- `prepare_traindataset`: returns a `torch.Dataset` with train data
- `prepare_valdataset`: returns a `torch.Dataset` with validation data
- `prepare_trainvaldatasets`: a generator yielding tuples of two `torch.Dataset`s, one with train and one with validation data (if this method is implemented, `prepare_traindataset` and `prepare_valdataset` should not be)
- `prepare_testdataset`: returns a `torch.Dataset` with test data
- `prepare_predictdataset`: returns a `torch.Dataset` with data for prediction
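A minimal end-to-end sketch of such a subclass might look as follows. The import path is assumed from "Source code in kit4dl/nn/dataset.py", and the synthetic tensor dataset is purely illustrative:

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split

try:
    # assumed import path, inferred from the source location above
    from kit4dl.nn.dataset import Kit4DLAbstractDataModule as Base
except ImportError:
    Base = object  # lets the sketch run without kit4dl installed

class ToyDataModule(Base):
    def prepare_data(self):
        # one-time download/preprocessing before any split is created
        pass

    def prepare_trainvaldatasets(self, n_samples: int = 100):
        # synthetic stand-in for a real dataset
        dset = TensorDataset(
            torch.randn(n_samples, 3), torch.randint(0, 2, (n_samples,))
        )
        # a single train/validation split, yielded as one tuple
        train_dset, val_dset = random_split(
            dset, [80, 20], generator=torch.Generator().manual_seed(0)
        )
        yield train_dset, val_dset

    def prepare_testdataset(self, n_samples: int = 20) -> Dataset:
        return TensorDataset(torch.randn(n_samples, 3))
```

The method arguments (`n_samples`, or `root_dir` in the examples below) are supplied from the `[dataset.*]` sections of the configuration file.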

Source code in kit4dl/nn/dataset.py
def __init__(self, conf: DatasetConf):
    super().__init__()
    self.conf = conf
    self.trainval_dataset_generator: (
        Generator[tuple[Dataset, Dataset], None, None] | None
    ) = None
    self.train_dataset: Dataset | None = None
    self.val_dataset: Dataset | None = None
    self.test_dataset: Dataset | None = None
    self.predict_dataset: Dataset | None = None
    self._split_suffix = os.environ.get(
        "KIT4DL_SPLIT_PREFIX", "(split={:02d})"
    )
    self._configure_logger()
    for extra_arg_key, extra_arg_value in self.conf.arguments.items():
        self.debug(
            "setting extra user-defined argument: %s:%s",
            extra_arg_key,
            extra_arg_value,
        )
        setattr(self, extra_arg_key, extra_arg_value)

prepare_data

prepare_data()

Prepare dataset for train/validation/test/predict splits.

Examples

```python
class MyDatamodule(Kit4DLAbstractDataModule):

    def prepare_data(self):
        # any logic you need to perform before creating splits
        download_dataset()
```
Source code in kit4dl/nn/dataset.py
def prepare_data(self):
    """Prepare dataset for train/validation/test/predict splits.

    Examples
    --------
    ```python
    class MyDatamodule(Kit4DLAbstractDataModule):

        def prepare_data(self):
            # any logic you need to perform before creating splits
            download_dataset()
    ```
    """

prepare_traindataset

prepare_traindataset(*args: Any, **kwargs: Any) -> Dataset

Prepare dataset for training.

Parameters

*args : Any
    List of positional arguments to setup the dataset
**kwargs : Any
    List of named arguments required to setup the dataset

Returns

train_dataset : Dataset
    A training dataset

Examples

```python
...
def prepare_traindataset(self, root_dir: str) -> Dataset:
    train_dset = MyDataset(root_dir=root_dir)
    return train_dset
```
Source code in kit4dl/nn/dataset.py
def prepare_traindataset(self, *args: Any, **kwargs: Any) -> Dataset:
    """Prepare dataset for training.

    Parameters
    ----------
    *args: Any
        List of positional arguments to setup the dataset
    **kwargs : Any
        List of named arguments required to setup the dataset

    Returns
    -------
    train_dataset : Dataset
        A training dataset

    Examples
    --------
    ```python
    ...
    def prepare_traindataset(self, root_dir: str) -> Dataset:
        train_dset = MyDataset(root_dir=root_dir)
        return train_dset
    ```
    """
    raise NotImplementedError

prepare_valdataset

prepare_valdataset(*args: Any, **kwargs: Any) -> Dataset

Prepare dataset for validation.

Parameters

*args : Any
    List of positional arguments to setup the dataset
**kwargs : Any
    List of named arguments required to setup the dataset

Returns

val_dataset : Dataset
    A validation dataset

Examples

```python
...
def prepare_valdataset(self, root_dir: str) -> Dataset:
    val_dset = MyDataset(root_dir=root_dir)
    return val_dset
```
Source code in kit4dl/nn/dataset.py
def prepare_valdataset(self, *args: Any, **kwargs: Any) -> Dataset:
    """Prepare dataset for validation.

    Parameters
    ----------
    *args: Any
        List of positional arguments to setup the dataset
    **kwargs : Any
        List of named arguments required to setup the dataset

    Returns
    -------
    val_dataset : Dataset
        A validation dataset

    Examples
    --------
    ```python
    ...
    def prepare_valdataset(self, root_dir: str) -> Dataset:
        val_dset = MyDataset(root_dir=root_dir)
        return val_dset
    ```
    """
    raise NotImplementedError

prepare_trainvaldatasets

prepare_trainvaldatasets(*args: Any, **kwargs: Any) -> Generator[tuple[Dataset, Dataset], None, None]

Prepare dataset for training and validation.

Parameters

*args : Any
    List of positional arguments to setup the dataset
**kwargs : Any
    List of named arguments required to setup the dataset

Yields

trainval_datasets : tuple of two Datasets
    Tuple consisting of a train and a validation dataset, one per split

Examples

```python
...
def prepare_trainvaldatasets(self, root_dir: str):
    dset = MyDataset(root_dir=root_dir)
    train_dset, val_dset = random_split(dset, [500, 50])
    yield train_dset, val_dset
```
Source code in kit4dl/nn/dataset.py
def prepare_trainvaldatasets(
    self, *args: Any, **kwargs: Any
) -> Generator[tuple[Dataset, Dataset], None, None]:
    """Prepare dataset for training and validation.

    Parameters
    ----------
    *args: Any
        List of positional arguments to setup the dataset
    **kwargs : Any
        List of named arguments required to setup the dataset

    Yields
    ------
    trainval_datasets : tuple of two Datasets
        Tuple consisting of a train and a validation dataset, one per
        split

    Examples
    --------
    ```python
    ...
    def prepare_trainvaldatasets(self, root_dir: str):
        dset = MyDataset(root_dir=root_dir)
        train_dset, val_dset = random_split(dset, [500, 50])
        yield train_dset, val_dset
    ```
    """
    raise NotImplementedError

prepare_testdataset

prepare_testdataset(*args: Any, **kwargs: Any) -> Dataset

Prepare dataset for testing.

Parameters

*args : Any
    List of positional arguments to setup the dataset
**kwargs : Any
    List of named arguments required to setup the dataset

Returns

test_dataset : Dataset
    A test dataset

Examples

```python
...
def prepare_testdataset(self, root_dir: str) -> Dataset:
    return MyDataset(root_dir=root_dir)
```
Source code in kit4dl/nn/dataset.py
def prepare_testdataset(self, *args: Any, **kwargs: Any) -> Dataset:
    """Prepare dataset for testing.

    Parameters
    ----------
    *args: Any
        List of positional arguments to setup the dataset
    **kwargs : Any
        List of named arguments required to setup the dataset

    Returns
    -------
    test_datasets : Dataset
        A test dataset

    Examples
    --------
    ```python
    ...
    def prepare_testdataset(self, root_dir: str) -> Dataset:
        return MyDataset(root_dir=root_dir)
    ```
    """
    raise NotImplementedError

prepare_predictdataset

prepare_predictdataset(*args: Any, **kwargs: Any) -> Dataset

Prepare dataset for predicting.

Parameters

*args : Any
    List of positional arguments to setup the dataset
**kwargs : Any
    List of named arguments required to setup the dataset

Returns

pred_dataset : Dataset
    A prediction dataset

Examples

```python
...
def prepare_predictdataset(self, root_dir: str) -> Dataset:
    return MyDataset(root_dir=root_dir)
```
Source code in kit4dl/nn/dataset.py
def prepare_predictdataset(self, *args: Any, **kwargs: Any) -> Dataset:
    """Prepare dataset for predicting.

    Parameters
    ----------
    *args: Any
        List of positional arguments to setup the dataset
    **kwargs : Any
        List of named arguments required to setup the dataset

    Returns
    -------
    pred_datasets : Dataset
        A prediction dataset

    Examples
    --------
    ```python
    ...
    def prepare_predictdataset(self, root_dir: str) -> Dataset:
        return MyDataset(root_dir=root_dir)
    ```
    """
    raise NotImplementedError

setup

setup(stage: str) -> None

Set up data depending on the stage.

The method should not be overridden unless necessary.

Parameters

stage : str
    The stage of the pipeline. One out of `['fit', 'test', 'predict']`

Source code in kit4dl/nn/dataset.py
def setup(self, stage: str) -> None:
    """Set up data depending on the stage.

    The method should not be overridden unless necessary.

    Parameters
    ----------
    stage : str
        The stage of the pipeline. One out of `['fit', 'test', 'predict']`
    """
    match stage:
        case "fit":
            self._handle_fit_stage()
        case "test":
            self._handle_test_stage()
        case "predict":
            self._handle_predict_stage()

get_collate_fn

get_collate_fn() -> Callable | None

Get batch collate function.

Source code in kit4dl/nn/dataset.py
def get_collate_fn(self) -> Callable | None:
    """Get batch collate function."""
    return None
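As a sketch of how one might use this hook, a subclass could return a padding collate function for variable-length samples. The `(sequence, label)` sample shape here is an assumption for illustration; `pad_sequence` is the standard `torch.nn.utils.rnn` helper:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Pad variable-length (sequence, label) pairs to a common length."""
    seqs, labels = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True)  # shape (B, T_max, ...)
    return padded, torch.tensor(labels)

# In a datamodule subclass, one would then override:
#     def get_collate_fn(self):
#         return pad_collate
```

Because the split-specific hooks below default to `get_collate_fn`, overriding this single method applies the collate function to all dataloaders at once.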

get_train_collate_fn

get_train_collate_fn() -> Callable | None

Get train specific collate function.

Source code in kit4dl/nn/dataset.py
def get_train_collate_fn(self) -> Callable | None:
    """Get train specific collate function."""
    return self.get_collate_fn()

get_val_collate_fn

get_val_collate_fn() -> Callable | None

Get validation specific collate function.

Source code in kit4dl/nn/dataset.py
def get_val_collate_fn(self) -> Callable | None:
    """Get validation specific collate function."""
    return self.get_collate_fn()

get_test_collate_fn

get_test_collate_fn() -> Callable | None

Get test specific collate function.

Source code in kit4dl/nn/dataset.py
def get_test_collate_fn(self) -> Callable | None:
    """Get test specific collate function."""
    return self.get_collate_fn()

get_predict_collate_fn

get_predict_collate_fn() -> Callable | None

Get predict specific collate function.

Source code in kit4dl/nn/dataset.py
def get_predict_collate_fn(self) -> Callable | None:
    """Get predict specific collate function."""
    return self.get_collate_fn()

trainval_dataloaders

trainval_dataloaders() -> Generator[tuple[DataLoader, DataLoader, str], None, None]

Prepare loader for train and validation data.

Source code in kit4dl/nn/dataset.py
def trainval_dataloaders(
    self,
) -> Generator[tuple[DataLoader, DataLoader, str], None, None]:
    """Prepare loader for train and validation data."""
    if self.conf.trainval:
        assert self.trainval_dataset_generator is not None, (
            "did you forget to yield `torch.utils.data.Dataset`"
            " instances from the `prepare_trainvaldatasets` method?"
        )
        assert self.conf.train is not None, (
            "[dataset.train.loader] section is missing in the"
            " configuration file."
        )
        assert self.conf.validation is not None, (
            "[dataset.validation.loader] section is missing in the"
            " configuration file."
        )
        assert (
            self.conf.train.loader is not None
            and self.conf.validation.loader is not None
        ), (
            "did you forget to define [dataset.train.loader] and"
            " [dataset.validation.loader] sections in the configuration"
            " file?"
        )
        for i, (tr_dataset, val_dataset) in enumerate(
            self.trainval_dataset_generator
        ):
            yield DataLoader(
                tr_dataset,
                **self.conf.train.loader,
                collate_fn=self.get_train_collate_fn(),
            ), DataLoader(
                val_dataset,
                **self.conf.validation.loader,
                collate_fn=self.get_val_collate_fn(),
            ), self._split_suffix.format(
                i + 1
            )
    elif self.conf.train and self.conf.validation:
        yield self._train_dataloader(), self._val_dataloader(), ""

test_dataloader

test_dataloader() -> DataLoader

Prepare loader for test data.

Source code in kit4dl/nn/dataset.py
def test_dataloader(self) -> DataLoader:
    """Prepare loader for test data."""
    assert self.conf.test, (
        "test configuration is not defined. did you forget"
        " [dataset.test] section in the configuration file?"
    )
    assert self.test_dataset is not None, (
        "did you forget to return `torch.utils.data.Dataset` instance"
        " from the `prepare_testdataset` method?"
    )
    return DataLoader(
        self.test_dataset,
        **self.conf.test.loader,
        collate_fn=self.get_test_collate_fn(),
    )

predict_dataloader

predict_dataloader() -> DataLoader

Prepare loader for prediction data.

Source code in kit4dl/nn/dataset.py
def predict_dataloader(self) -> DataLoader:
    """Prepare loader for prediction data."""
    assert self.conf.predict, (
        "predict configuration is not defined. did you forget"
        " [dataset.predict] section in the configuration file?"
    )
    assert self.predict_dataset is not None, (
        "did you forget to return `torch.utils.data.Dataset` instance"
        " from the `prepare_predictdataset` method?"
    )
    return DataLoader(
        self.predict_dataset,
        **self.conf.predict.loader,
        collate_fn=self.get_predict_collate_fn(),
    )

numpy_train_dataloader

numpy_train_dataloader()

Prepare loader for train data for models accepting numpy.ndarray.

Source code in kit4dl/nn/dataset.py
def numpy_train_dataloader(self):
    """Prepare loader for train data for models accepting `numpy.ndarray`."""
    raise NotImplementedError

numpy_val_dataloader

numpy_val_dataloader()

Prepare loader for val data for models accepting numpy.ndarray.

Source code in kit4dl/nn/dataset.py
def numpy_val_dataloader(self):
    """Prepare loader for val data for models accepting `numpy.ndarray`."""
    raise NotImplementedError

numpy_testdataloader

numpy_testdataloader()

Prepare loader for test data for models accepting numpy.ndarray.

Source code in kit4dl/nn/dataset.py
def numpy_testdataloader(self):
    """Prepare loader for test data for models accepting `numpy.ndarray`."""
    raise NotImplementedError

numpy_predictdataloader

numpy_predictdataloader()

Prepare loader for pred data for models accepting numpy.ndarray.

Source code in kit4dl/nn/dataset.py
def numpy_predictdataloader(self):
    """Prepare loader for pred data for models accepting `numpy.ndarray`."""
    raise NotImplementedError

Custom splits

Note

Available since 2024.5b0

In Kit4DL, you can easily define the logic for cross-validation. Starting from version 2024.5b0, the old method prepare_trainvaldataset was replaced by the prepare_trainvaldatasets method, which is a generator whose logic you define yourself. To run 10-fold cross-validation, implement the method as follows:

```python
...
from sklearn.model_selection import KFold

class MNISTCustomDatamodule(Kit4DLAbstractDataModule):
    def prepare_trainvaldatasets(self, root_dir: str):
        dset = MNIST(
            root=root_dir,
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        split = KFold(n_splits=10, shuffle=True, random_state=0)
        for train_ind, val_ind in split.split(dset.data, dset.targets):
            yield Subset(dset, train_ind), Subset(dset, val_ind)
```

If you want to stick to the old logic and return a single split, just yield the corresponding datasets:

```python
...

class MNISTCustomDatamodule(Kit4DLAbstractDataModule):
    def prepare_trainvaldatasets(self, root_dir: str):
        tr_dset = MNIST(
            root=root_dir,
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        val_dset = MNIST(
            root=root_dir,
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )
        yield tr_dset, val_dset
```

Each generated tuple of train and validation datasets is fed into the training/validation pipeline. If you use external metric loggers, the results for each split are uploaded under the experiment name plus a suffix like (split=01).

The suffix template can be overridden via the KIT4DL_SPLIT_PREFIX environment variable.
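As a sketch, the template is read once in the datamodule's `__init__` via `os.environ.get("KIT4DL_SPLIT_PREFIX", "(split={:02d})")`, so the variable must be set before the datamodule is constructed. The replacement value below is a hypothetical example:

```python
import os

# override the default "(split={:02d})" template (hypothetical value)
os.environ["KIT4DL_SPLIT_PREFIX"] = "[fold {}]"

# mirrors how the datamodule reads and applies the template for split i
template = os.environ.get("KIT4DL_SPLIT_PREFIX", "(split={:02d})")
print(template.format(1))  # → [fold 1]
```

The template receives the 1-based split index, so the default renders as (split=01), (split=02), and so on.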