persia.data
In PERSIA, we provide the `DataLoader` class to load the data. The `DataLoader` preprocesses the `PersiaBatch` and looks up the embeddings for its id_type_features. To initialize a `DataLoader`, the dataset must be an iterable dataset (an instance of an `IterableDatasetBase` subclass). To generate an iterable dataset, you can use the `StreamingDataset` to fetch the `PersiaBatch` from the dataflow, or the `IterableDataset` to generate the `PersiaBatch` locally.
Module Contents
- class persia.data.DataLoader(dataset, forward_buffer_size=10, timeout_ms=1000 * 60 * 10, num_workers=10, reproducible=False, embedding_staleness=None)
Data loader will preprocess the data into the `PersiaTrainingBatch`. The `DataLoader` is a pipeline that preprocesses the `PersiaBatch` in several steps, each of which processes its task concurrently with multiple threads to improve efficiency.

Warning

When you use the `StreamingDataset`, the `DataLoader` cannot stop the iteration unless a `TimeoutError` is raised (see `StreamingDataset` for more details).

- Parameters:
  - dataset (IterableDatasetBase) – dataset for the DataLoader to retrieve replica info and the sender channel from.
  - forward_buffer_size (int, optional) – `PersiaTrainingBatch` buffer size; this argument affects the GPU memory cost.
  - timeout_ms (int, optional) – timeout for data fetching, in milliseconds.
  - num_workers (int, optional) – number of spawned thread workers to look up embeddings and prefetch the `PersiaBatch`.
  - reproducible (bool, optional) – iterate the data in a fixed order, making the dataflow deterministic.
  - embedding_staleness (int, optional) – max number of batched stale embeddings on each rank. A stale embedding is one that was prefetched from the embedding server before the gradient update.
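Here is a minimal sketch of how these arguments fit together, assuming an upstream dataflow is already sending `PersiaBatch` data (see `StreamingDataset` below); the argument values simply restate the defaults documented above.

```python
from persia.data import DataLoader, StreamingDataset

# Sketch only: assumes an upstream dataflow is feeding this process.
dataset = StreamingDataset(buffer_size=10)
data_loader = DataLoader(
    dataset,
    forward_buffer_size=10,     # PersiaTrainingBatch buffer size (affects GPU memory)
    timeout_ms=1000 * 60 * 10,  # raise TimeoutError after 10 minutes without data
)

for persia_training_batch in data_loader:
    ...  # run the forward/backward pass on the looked-up embeddings
```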
- class persia.data.IterableDataset(buffer_size=10)
Bases: `IterableDatasetBase`

The `IterableDataset` can iterate through the dataset multiple times, whereas the `StreamingDataset` iterates through the dataset only once. It is advised that you implement a test dataset using `IterableDataset`.

Implement the `__iter__` function to define the `PersiaBatch` generation phase.

```python
import numpy as np

from persia.data import IterableDataset, DataLoader
from persia.embedding.data import PersiaBatch, IDTypeFeature


class MyTestDataset(IterableDataset):
    def __init__(self):
        super(MyTestDataset, self).__init__()
        self.size = 10

    def __iter__(self):
        # Generate a fixed number of PersiaBatch instances per iteration pass.
        for i in range(self.size):
            persia_batch = PersiaBatch(
                id_type_features=[
                    IDTypeFeature(
                        "id_type_feature_slot",
                        [
                            np.array([1000, 10001], dtype=np.uint64),
                            np.array([1003, 10011], dtype=np.uint64),
                        ],
                    )
                ],
                requires_grad=False,
            )
            yield persia_batch


dataset = MyTestDataset()
dataloader = DataLoader(dataset)
```
- Parameters:
  - buffer_size (int, optional) – `PersiaBatch` buffer size.
- consume_dataset()
  Consume the `__iter__` of itself and return an iterator of preprocess indexes.
  - Return type: Iterator[int]
- class persia.data.IterableDatasetBase(buffer_size=10)
Bases: `abc.ABC`, `Iterable[persia.embedding.data.PersiaBatch]`

The role of `IterableDatasetBase` is to transfer the `PersiaBatch` to the `DataLoader`. It wraps the `PersiaBatchDataChannel`, which provides the ability to send data to the `DataLoader`. It has a sender (`PersiaBatchDataSender`) and a receiver (`PersiaBatchDataReceiver`), whose functionalities are illustrated in the example below.

This class cannot be used directly unless it implements the `__iter__` and `consume_dataset` functions to be compatible with the `DataLoader`. The `__iter__` function generates the `PersiaBatch`, and `consume_dataset` sends the `PersiaBatch` through the `PersiaBatchDataSender`.

Here is an example that implements a synchronous `IterableDatasetBase`.

```python
from typing import Iterator

import numpy as np

from persia.data import IterableDatasetBase
from persia.embedding.data import PersiaBatch, IDTypeFeature


class MyPersiaIterableDataset(IterableDatasetBase):
    def __iter__(self):
        persia_batch = PersiaBatch(
            id_type_features=[
                IDTypeFeature(
                    "id_type_feature_slot",
                    [
                        np.array([1000, 10001], dtype=np.uint64),
                        np.array([1003, 10011], dtype=np.uint64),
                    ],
                )
            ],
            requires_grad=False,
        )
        yield persia_batch
        yield persia_batch

    def consume_dataset(self) -> Iterator[int]:
        # Send each generated PersiaBatch synchronously and yield its index.
        for preprocess_idx, persia_batch in enumerate(self):
            self.sender.send(persia_batch)
            yield preprocess_idx
```
Note

The `MyPersiaIterableDataset` implemented in the above example will be slow if you are dealing with a large dataset, since it processes the `PersiaBatch` synchronously. If you want to improve the performance of data processing, use the `IterableDataset` or the `StreamingDataset` instead.

- Parameters:
  - buffer_size (int, optional) – buffer size for the `PersiaBatchDataChannel`.
- abstract consume_dataset()
  Consume the `__iter__` of itself and return an iterator of preprocess indexes.
  - Return type: Iterator[int]
- class persia.data.StreamingDataset(buffer_size=10)
Bases: `IterableDatasetBase`

Streaming dataset receives the `PersiaBatch` from the upstream data flow sent by `DataCtx`.

In the implemented `StreamingDataset.consume_dataset`, the `PersiaBatchDataSender` instance binds to the RPC service that receives the data automatically, so it is not necessary to implement the `__iter__` function.

Warning

`StreamingDataset` will make the `DataLoader` raise a `TimeoutError` if the upstream data flow is drained.

- Parameters:
  - buffer_size (int, optional) – `PersiaBatchDataChannel` buffer size.
- consume_dataset()
  Consume the `__iter__` of itself and return an iterator of preprocess indexes.
  - Return type: Iterator[int]
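To show where a `StreamingDataset` gets its data from, here is a hedged sketch of the upstream side. It assumes a separate data process using `DataCtx` from `persia.ctx` and its `send_data` method; the feature values are placeholders.

```python
import numpy as np

from persia.ctx import DataCtx
from persia.embedding.data import PersiaBatch, IDTypeFeature

# Hypothetical upstream data process: send_data feeds the dataflow
# that a StreamingDataset on the trainer side receives.
with DataCtx() as ctx:
    persia_batch = PersiaBatch(
        id_type_features=[
            IDTypeFeature(
                "id_type_feature_slot",
                [np.array([1000, 10001], dtype=np.uint64)],
            )
        ],
        requires_grad=False,
    )
    ctx.send_data(persia_batch)
```

On the trainer side, wrapping a `StreamingDataset` in a `DataLoader` (as sketched under `DataLoader` above) is enough to receive these batches.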