persia.helper
Module Contents
- class persia.helper.PersiaServiceCtx(data_loader_func=None, nn_worker_func=None, embedding_config=None, embedding_config_path=None, global_config=None, global_config_path=None, data_loader_replica_num=1, nproc_per_node=1, embedding_worker_replica_num=1, embedding_parameter_replica_num=1, embedding_worker_port=7777, embedding_parameter_server_port=8888, nats_server_port=4222)
Launch the required processes to mock the distributed PERSIA environment.
You can pass either an embedding config file path or a dict containing the embedding config. The embedding configuration applies to both the embedding_worker and the embedding_parameter_server. Similarly, you can set the global config with either a config file path or a dict.
For example, when the model is trained with TrainCtx under the PersiaServiceCtx context, PersiaServiceCtx launches the DataCtx in a subprocess:

```python
import torch

from persia.helper import PersiaServiceCtx
from persia.ctx import DataCtx, TrainCtx
from persia.data import DataLoader, StreamingDataset
from persia.embedding import IDTypeFeature, PersiaBatch, Label
from persia.embedding.optim import Adam

embedding_config = {"slots_config": {"age": {"dim": 8}}}


def data_loader_func():
    # numpy is imported inside the function, which runs in its own process.
    import numpy as np

    with DataCtx() as data_ctx:
        for i in range(5):
            persia_batch = PersiaBatch(
                id_type_features=[
                    # Illustrative values: one sample with three IDs for the "age" slot.
                    IDTypeFeature("age", [np.array([1, 2, 3], dtype=np.uint64)])
                ],
                labels=[Label(np.array([0], dtype=np.float32))],
            )
            data_ctx.send_data(persia_batch)


with PersiaServiceCtx(
    embedding_config=embedding_config, data_loader_func=data_loader_func
):
    prefetch_buffer_size = 15
    with TrainCtx(...) as ctx:  # TrainCtx arguments are elided here
        data_loader = DataLoader(StreamingDataset(prefetch_buffer_size))
        for persia_training_batch in data_loader:
            ...
```
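The general pattern behind a context like this, starting helper processes when the with-block is entered and stopping them when it exits, can be sketched with the standard library alone. This is a simplified, hypothetical illustration (MockServiceCtx and SLEEPER are made-up names, and each "service" is just a sleeping Python process), not PersiaServiceCtx's actual implementation.

```python
import subprocess
import sys
from typing import List


class MockServiceCtx:
    """Hypothetical stand-in for a service context: start helper
    processes on entry, stop them on exit."""

    def __init__(self, commands: List[List[str]]):
        self.commands = commands
        self.procs: List[subprocess.Popen] = []

    def __enter__(self) -> "MockServiceCtx":
        # Launch every helper process before the with-block body runs.
        for cmd in self.commands:
            self.procs.append(subprocess.Popen(cmd))
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Terminate all helpers, even if the body raised an exception.
        for proc in self.procs:
            if proc.poll() is None:
                proc.terminate()
        for proc in self.procs:
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        # Returning None lets any exception from the body propagate.


# Stand-in "service": a Python process that just sleeps.
SLEEPER = [sys.executable, "-c", "import time; time.sleep(60)"]

if __name__ == "__main__":
    with MockServiceCtx([SLEEPER, SLEEPER]) as ctx:
        print("running:", [p.poll() is None for p in ctx.procs])
    print("stopped:", [p.poll() is not None for p in ctx.procs])
```

Tying process lifetime to the with-block means the helpers are cleaned up on both normal exit and error, which is what makes a context manager a natural fit for a mock environment.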
- Parameters:
data_loader_func (Callable, optional) – data loader function that is pickled and run in a separate process.
nn_worker_func (Callable, optional) – nn_worker function that is pickled and run in a separate process.
embedding_config (dict, optional) – PERSIA embedding config; see the configuration reference.
embedding_config_path (str, optional) – PERSIA embedding config path.
global_config (dict, optional) – PERSIA global config; see the configuration reference.
global_config_path (str, optional) – PERSIA global config path.
data_loader_replica_num (int, optional) – number of data_loader processes.
nproc_per_node (int, optional) – number of processes per node for data-parallel training.
embedding_worker_replica_num (int, optional) – number of embedding_worker processes.
embedding_parameter_replica_num (int, optional) – number of embedding_parameter_server processes.
embedding_worker_port (int, optional) – port of the embedding-worker server.
embedding_parameter_server_port (int, optional) – port of the embedding-parameter-server.
nats_server_port (int, optional) – port of the nats-server.
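Because the three port parameters have fixed defaults (7777, 8888 and 4222), a run can fail if another program already holds one of them. The sketch below checks whether each default port can be bound on the local machine before entering the context. It is a hypothetical helper (port_is_free, check_ports and DEFAULT_PORTS are illustrative names, not part of persia.helper), and it assumes the services bind these TCP ports on localhost, which this page does not state.

```python
import socket
from typing import Dict

# Default ports taken from the PersiaServiceCtx signature above.
DEFAULT_PORTS = {
    "embedding_worker_port": 7777,
    "embedding_parameter_server_port": 8888,
    "nats_server_port": 4222,
}


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can bind to (host, port) right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            # Typically EADDRINUSE: another process already holds the port.
            return False
        return True


def check_ports(ports: Dict[str, int], host: str = "127.0.0.1") -> Dict[str, bool]:
    """Map each port parameter name to whether its port is currently free."""
    return {name: port_is_free(port, host) for name, port in ports.items()}


if __name__ == "__main__":
    for name, free in check_ports(DEFAULT_PORTS).items():
        print(f"{name}={DEFAULT_PORTS[name]}: {'free' if free else 'in use'}")
```

The check is only a snapshot: another process can claim a port between the check and the launch. If a port is busy, passing a different value through the matching *_port argument is the intended way to avoid the clash.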
- persia.helper.ensure_persia_service(*args, **kwargs)
- Return type: