persia.data
In PERSIA, we provide the DataLoader class to load the data. The DataLoader preprocesses the PersiaBatch and looks up the embeddings for its id_type_features. In order to initialize a DataLoader, the dataset must be an iterable dataset (an instance of an IterableDatasetBase subclass). To generate an iterable dataset, you can use the StreamingDataset to fetch the PersiaBatch from the dataflow, or the IterableDataset to generate the PersiaBatch locally.
Module Contents
- class persia.data.DataLoader(dataset, forward_buffer_size=10, timeout_ms=1000 * 60 * 10, num_workers=10, reproducible=False, embedding_staleness=None)
Data loader that preprocesses the data into a PersiaTrainingBatch.

The DataLoader is a pipeline that preprocesses the PersiaBatch in several steps. Each step processes its task concurrently with multiple threads to improve efficiency.

Warning

When you use a StreamingDataset, the DataLoader cannot stop the iteration by itself; iteration only ends once a TimeoutError is raised (see StreamingDataset for more details).

- Parameters:
dataset (IterableDatasetBase) – dataset for the DataLoader to retrieve replica info and the sender channel from.
forward_buffer_size (int, optional) – PersiaTrainingBatch buffer size; this argument affects GPU memory usage.
timeout_ms (int, optional) – timeout for data fetching, in milliseconds.
num_workers (int, optional) – number of spawned worker threads for embedding lookup and PersiaBatch prefetching.
reproducible (bool, optional) – iterate the data in a fixed order, making the dataflow deterministic.
embedding_staleness (int, optional) – maximum number of stale embedding batches on each rank. A stale embedding is one that was prefetched from the embedding server before the latest gradient update.
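The buffering behavior described by forward_buffer_size and num_workers can be illustrated with a minimal, self-contained sketch in plain Python (no PERSIA imports; MiniDataLoader and lookup_embedding are hypothetical stand-ins for illustration only, not part of the PERSIA API):

```python
import queue
import threading


def lookup_embedding(batch):
    """Hypothetical stand-in for the embedding-lookup step."""
    return batch * 10


class MiniDataLoader:
    """Toy pipeline: num_workers threads prefetch into a bounded buffer,
    mimicking the roles of forward_buffer_size and num_workers."""

    def __init__(self, dataset, forward_buffer_size=10, num_workers=2):
        self.buffer = queue.Queue(maxsize=forward_buffer_size)
        self.total = len(dataset)
        source = iter(dataset)
        lock = threading.Lock()

        def worker():
            while True:
                with lock:  # the workers share one iterator safely
                    try:
                        item = next(source)
                    except StopIteration:
                        return
                self.buffer.put(lookup_embedding(item))

        for _ in range(num_workers):
            threading.Thread(target=worker, daemon=True).start()

    def __iter__(self):
        for _ in range(self.total):
            yield self.buffer.get()


# Workers run concurrently, so arrival order may vary; sort for display.
out = sorted(MiniDataLoader(range(5)))
```

Bounding the buffer caps how many preprocessed batches can pile up at once, which is why forward_buffer_size affects GPU memory cost in the real DataLoader.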
- class persia.data.IterableDataset(buffer_size=10)
Bases: IterableDatasetBase

The IterableDataset can iterate through the dataset multiple times, whereas a StreamingDataset iterates over the dataset only once. It is advised that you implement your test dataset with IterableDataset.

Implement the
__iter__ function to define the PersiaBatch generation phase.

```python
import numpy as np

from persia.data import IterableDataset, DataLoader
from persia.embedding.data import PersiaBatch, IDTypeFeature


class MyTestDataset(IterableDataset):
    def __init__(self):
        super(MyTestDataset, self).__init__()
        self.size = 10

    def __iter__(self):
        for i in range(self.size):
            persia_batch = PersiaBatch(
                id_type_features=IDTypeFeature(
                    "id_type_feature_slot",
                    [
                        np.array([1000, 10001], dtype=np.uint64),
                        np.array([1003, 10011], dtype=np.uint64),
                    ],
                ),
                requires_grad=False,
            )
            yield persia_batch


dataset = MyTestDataset()
dataloader = DataLoader(dataset)
```
- Parameters:
buffer_size (int, optional) – PersiaBatch buffer size
- consume_dataset()
Consume __iter__ of itself and return the iterator of preprocess indexes.
- Return type:
Iterator[int]
- class persia.data.IterableDatasetBase(buffer_size=10)
Bases: abc.ABC, Iterable[persia.embedding.data.PersiaBatch]

The role of IterableDatasetBase is to transfer the PersiaBatch to the DataLoader. It wraps a PersiaBatchDataChannel, which provides the ability to send data to the DataLoader. It has a sender (PersiaBatchDataSender) and a receiver (PersiaBatchDataReceiver), whose functionalities are illustrated in the example below.

This class cannot be used directly; implement its __iter__ and consume_dataset functions to make it compatible with the DataLoader. The __iter__ function generates the PersiaBatch, and consume_dataset sends the PersiaBatch through the PersiaBatchDataSender.

Here is an example that implements a synchronous
IterableDatasetBase.

```python
from typing import Iterator

import numpy as np

from persia.data import IterableDatasetBase
from persia.embedding.data import PersiaBatch, IDTypeFeature


class MyPersiaIterableDataset(IterableDatasetBase):
    def __iter__(self):
        persia_batch = PersiaBatch(
            id_type_features=IDTypeFeature(
                "id_type_feature_slot",
                [
                    np.array([1000, 10001], dtype=np.uint64),
                    np.array([1003, 10011], dtype=np.uint64),
                ],
            ),
            requires_grad=False,
        )
        yield persia_batch
        yield persia_batch

    def consume_dataset(self) -> Iterator[int]:
        for preprocess_idx, persia_batch in enumerate(self):
            self.sender.send(persia_batch)
            yield preprocess_idx
```
Note
The MyPersiaIterableDataset implemented in the example above will be slow for a large dataset, since it processes the PersiaBatch synchronously. To improve data-processing performance, use the IterableDataset or StreamingDataset instead.

- Parameters:
buffer_size (int, optional) – buffer size for the PersiaBatchDataChannel.
- abstract consume_dataset()
Consume __iter__ of itself and return the iterator of preprocess indexes.
- Return type:
Iterator[int]
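To see why moving consume_dataset off the consumer's thread helps, here is a toy comparison in plain Python (no PERSIA imports; SyncSource and consume_async are hypothetical names): batch generation runs in a background thread and fills a bounded buffer, so the consumer yields preprocess indexes without blocking on batch construction.

```python
import queue
import threading


class SyncSource:
    """Stand-in for a dataset whose __iter__ builds batches synchronously."""

    def __iter__(self):
        for i in range(5):
            yield i  # imagine an expensive PersiaBatch construction here


def consume_async(source, buffer_size=10):
    """Toy asynchronous consume: generation runs in a background thread,
    which 'sends' each batch and lets the caller collect indexes."""
    buf = queue.Queue(maxsize=buffer_size)
    sentinel = object()

    def produce():
        for idx, batch in enumerate(source):
            buf.put((idx, batch))  # stands in for sender.send(batch)
        buf.put(sentinel)

    threading.Thread(target=produce, daemon=True).start()
    while True:
        item = buf.get()
        if item is sentinel:
            return
        yield item[0]  # the preprocess index


idxs = list(consume_async(SyncSource()))
```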
- class persia.data.StreamingDataset(buffer_size=10)
Bases: IterableDatasetBase

The streaming dataset receives the PersiaBatch from the upstream data flow sent by the DataCtx.

In the provided StreamingDataset.consume_dataset implementation, the PersiaBatchDataSender instance binds to the RPC service that receives the data automatically, so it is not necessary to implement the __iter__ function.

Warning

The StreamingDataset will make the DataLoader raise a TimeoutError if the upstream data flow is drained.

- Parameters:
buffer_size (int, optional) – PersiaBatchDataChannel buffer size
- consume_dataset()
Consume __iter__ of itself and return the iterator of preprocess indexes.
- Return type:
Iterator[int]
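The TimeoutError behavior noted in the warning above can be sketched with a bounded queue in plain Python (MiniStreamingLoader and drained_stream are hypothetical stand-ins, not PERSIA classes): once the upstream producer stops, the next fetch times out and the iteration ends with an error.

```python
import queue
import threading


def drained_stream(n):
    """Upstream that sends n items and then drains."""
    for i in range(n):
        yield i


class MiniStreamingLoader:
    """Toy consumer: raises TimeoutError once the upstream is drained."""

    def __init__(self, source, timeout_ms=200):
        self.buffer = queue.Queue(maxsize=10)
        self.timeout = timeout_ms / 1000
        threading.Thread(target=self._pump, args=(source,), daemon=True).start()

    def _pump(self, source):
        for item in source:
            self.buffer.put(item)
        # the producer simply exits: no sentinel, just silence

    def __iter__(self):
        while True:
            try:
                yield self.buffer.get(timeout=self.timeout)
            except queue.Empty:
                raise TimeoutError("upstream data flow drained")


batches = []
try:
    for batch in MiniStreamingLoader(drained_stream(3)):
        batches.append(batch)
except TimeoutError:
    pass  # the training loop would stop the epoch here
```

Catching the TimeoutError around the iteration loop is how a training job can end an epoch gracefully when the upstream dataflow stops producing.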