persia.helper

Module Contents

class persia.helper.PersiaServiceCtx(data_loader_func=None, nn_worker_func=None, embedding_config=None, embedding_config_path=None, global_config=None, global_config_path=None, data_loader_replica_num=1, nproc_per_node=1, embedding_worker_replica_num=1, embedding_parameter_replica_num=1, embedding_worker_port=7777, embedding_parameter_server_port=8888, nats_server_port=4222)

Launch the required processes to mock the distributed PERSIA environment.
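The launch-and-teardown pattern behind a context manager like this can be illustrated with the standard library alone. This is a minimal sketch under stated assumptions, not PERSIA's implementation. `launch_services` is a hypothetical name, and each role is an arbitrary command with a replica count that loosely mirrors the `*_replica_num` parameters. The sketch terminates every process it started when the `with` block exits, even if the body raises.

```python
import subprocess
import sys
from contextlib import contextmanager
from typing import Dict, Iterator, List, Sequence, Tuple


@contextmanager
def launch_services(
    roles: Dict[str, Tuple[Sequence[str], int]],
) -> Iterator[Dict[str, List[subprocess.Popen]]]:
    """Hypothetical stand-in for PersiaServiceCtx (not the PERSIA API).

    ``roles`` maps a role name to (command, replica_num). All replicas are
    started on enter; on exit, any process still running is terminated.
    """
    procs: Dict[str, List[subprocess.Popen]] = {}
    try:
        for name, (cmd, replica_num) in roles.items():
            group = procs.setdefault(name, [])
            for _ in range(replica_num):
                # Append one at a time so a failed launch cannot leak earlier ones.
                group.append(subprocess.Popen(list(cmd)))
        yield procs
    finally:
        for group in procs.values():
            for p in group:
                if p.poll() is None:  # still running
                    p.terminate()
        for group in procs.values():
            for p in group:
                try:
                    p.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    p.kill()  # did not exit after terminate(); force it
                    p.wait()
```

Using a `try`/`finally` around the `yield` is what guarantees cleanup: the child processes do not outlive the `with` block, whether the body finishes normally or raises.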

You can pass the embedding config either as a file path (embedding_config_path) or as a dict (embedding_config); the embedding configuration applies to both the embedding_worker and the embedding_parameter_server. Similarly, you can set the global config using either a file path (global_config_path) or a dict (global_config).
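One way to picture the "dict or file path" choice is a helper that always hands the separate server processes a file on disk. The sketch below is hypothetical: `resolve_config_path` is not part of persia.helper. Rejecting a call that passes both a dict and a path is a choice made by this sketch; PersiaServiceCtx may resolve that case differently. PERSIA config files are typically YAML. This sketch writes JSON only to stay within the standard library, so a real helper would more likely use a YAML library.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


def resolve_config_path(
    config: Optional[Dict[str, Any]] = None,
    config_path: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical helper: turn "a dict or a file path" into one file path.

    A dict is written to a temporary file so that several processes (e.g. an
    embedding worker and an embedding parameter server) can read the same file.
    """
    if config is not None and config_path is not None:
        # Assumption of this sketch, not documented PersiaServiceCtx behavior.
        raise ValueError("pass either a config dict or a config path, not both")
    if config_path is not None:
        return config_path  # already on disk; use it as-is
    if config is None:
        return None  # nothing configured
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(config, f)  # JSON only to avoid a third-party YAML dependency
    return path
```

For example, `resolve_config_path(config={"slots_config": {"age": {"dim": 8}}})` returns the path of a temporary file that holds that dict.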

For example, when you train a model with TrainCtx inside a PersiaServiceCtx context, PersiaServiceCtx launches the DataCtx (created in data_loader_func) in a subprocess.

import torch

from persia.helper import PersiaServiceCtx
from persia.ctx import DataCtx, TrainCtx
from persia.data import DataLoader, StreamingDataset
from persia.embedding import IDTypeFeature, PersiaBatch, Label
from persia.embedding.optim import Adam

embedding_config = {"slots_config": {"age": {"dim": 8}}}

def data_loader_func():
    import numpy as np

    with DataCtx() as data_ctx:
        for i in range(5):
            persia_batch = PersiaBatch(
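                id_type_features=[
                    # one ID-type feature for the "age" slot; each sample's IDs are a uint64 array
                    IDTypeFeature("age", [np.array([1, 2, 3], dtype=np.uint64)])
                ],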
                labels=[Label(np.array([0], dtype=np.float32))]
            )
            data_ctx.send_data(persia_batch)


with PersiaServiceCtx(
    embedding_config=embedding_config,
    data_loader_func=data_loader_func
):
    prefetch_buffer_size = 15
    with TrainCtx(
        ...
    ) as ctx:
        data_loader = DataLoader(StreamingDataset(prefetch_buffer_size))
        for persia_training_batch in data_loader:
            ...
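
The name prefetch_buffer_size suggests a bounded buffer between the data loader and the training loop. The sketch below illustrates only that general producer/consumer idea, using a thread and queue.Queue. It is not how StreamingDataset is implemented: in PERSIA the batches come from separate processes. The names `producer`, `streaming_batches` and the end-of-stream sentinel are hypothetical choices of this sketch.

```python
import queue
import threading
from typing import Iterator

_END = object()  # sentinel marking the end of the stream (a choice of this sketch)


def producer(buf: "queue.Queue[object]", num_batches: int) -> None:
    """Stand-in for data_loader_func: put() blocks while the buffer is full."""
    for i in range(num_batches):
        buf.put({"batch_id": i})
    buf.put(_END)


def streaming_batches(prefetch_buffer_size: int, num_batches: int) -> Iterator[object]:
    """Stand-in for iterating a DataLoader over a bounded prefetch buffer."""
    # At most prefetch_buffer_size batches wait in memory at any time.
    buf: "queue.Queue[object]" = queue.Queue(maxsize=prefetch_buffer_size)
    t = threading.Thread(target=producer, args=(buf, num_batches), daemon=True)
    t.start()
    while True:
        item = buf.get()  # blocks until the producer has a batch ready
        if item is _END:
            break
        yield item
    t.join()
```

A larger buffer lets the producer run further ahead of a slow consumer at the cost of memory. A smaller one keeps memory bounded but stalls the producer sooner.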

Parameters:
  • data_loader_func (Callable, optional) – data loader function that is pickled and run in a separate process.

  • nn_worker_func (Callable, optional) – nn_worker function that is pickled and run in a separate process.

  • embedding_config (dict, optional) – PERSIA embedding config; see the configuration reference.

  • embedding_config_path (str, optional) – path to the PERSIA embedding config file.

  • global_config (dict, optional) – PERSIA global config; see the configuration reference.

  • global_config_path (str, optional) – path to the PERSIA global config file.

  • data_loader_replica_num (int, optional) – number of data_loader processes.

  • nproc_per_node (int, optional) – number of processes per node for data-parallel training.

  • embedding_worker_replica_num (int, optional) – number of embedding_worker processes.

  • embedding_parameter_replica_num (int, optional) – number of embedding_parameter_server processes.

  • embedding_worker_port (int, optional) – port of the embedding_worker server.

  • embedding_parameter_server_port (int, optional) – port of the embedding_parameter_server.

  • nats_server_port (int, optional) – port of the NATS server (nats-server).
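The port parameters default to embedding_worker_port=7777, embedding_parameter_server_port=8888 and nats_server_port=4222. If one of these is already taken on your machine, you can check before launching. The sketch below uses only the standard library, and `port_is_free`, `pick_free_port` and `choose_ports` are hypothetical helpers, not part of persia.helper. A port reported free can still be claimed by another program before the service binds it, so treat the result as a best-effort check.

```python
import socket
from typing import Dict


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can currently bind to (host, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
        except OSError:  # e.g. address already in use
            return False
        return True


def pick_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for an unused ephemeral port (binding to port 0)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


def choose_ports(defaults: Dict[str, int]) -> Dict[str, int]:
    """Keep each default port if it is free, otherwise substitute a free one."""
    return {
        name: port if port_is_free(port) else pick_free_port()
        for name, port in defaults.items()
    }
```

Because the dict keys can be the real keyword names, the result of `choose_ports({"embedding_worker_port": 7777, "embedding_parameter_server_port": 8888, "nats_server_port": 4222})` can be unpacked into `PersiaServiceCtx(embedding_config=embedding_config, **ports)`.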

persia.helper.ensure_persia_service(*args, **kwargs)

Return type: PersiaServiceCtx