persia.ctx
Module Contents
- class persia.ctx.BaseCtx(threadpool_worker_size=10, device_id=None)
Initializes a common context for the other PERSIA contexts, e.g. DataCtx, EmbeddingCtx, and TrainCtx. This class should not be instantiated directly.
- Parameters:
threadpool_worker_size (int) – rpc threadpool worker size.
device_id (int, optional) – the CUDA device to use for this process.
- class persia.ctx.DataCtx(*args, **kwargs)
Bases:
BaseCtx
Data context that provides the communication functionality to the data generator component. Used for sending a PersiaBatch to the nn-worker and embedding-worker.
If you use the DataCtx to send the PersiaBatch on the data-loader, you should use the StreamingDataset to receive the data on the nn-worker.
On the data-loader:

    from persia.ctx import DataCtx
    from persia.embedding.data import PersiaBatch

    loader = make_loader()

    with DataCtx() as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=True
            )
            ctx.send_data(persia_batch)
On the nn-worker:

    from persia.ctx import TrainCtx
    from persia.data import StreamingDataset, DataLoader

    buffer_size = 15
    streaming_dataset = StreamingDataset(buffer_size)
    data_loader = DataLoader(streaming_dataset)

    with TrainCtx(...):
        for persia_training_batch in data_loader:
            ...
Note
The examples cannot be run directly; you should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the examples get the correct result.
- Parameters:
threadpool_worker_size (int) – rpc threadpool worker size.
device_id (int, optional) – the CUDA device to use for this process.
- send_data(persia_batch)
Send a PersiaBatch from the data loader to the nn-worker and embedding-worker side.
- Parameters:
persia_batch (PersiaBatch) – the PersiaBatch that has not been processed.
- class persia.ctx.EmbeddingCtx(preprocess_mode, model=None, embedding_config=None, *args, **kwargs)
Bases:
BaseCtx
Provides the embedding-related functionality. EmbeddingCtx can run an offline test or online inference depending on the preprocess_mode. The simplest way to get this context is by using eval_ctx to get an EmbeddingCtx instance.
Example for EmbeddingCtx:

    from persia.ctx import EmbeddingCtx, PreprocessMode
    from persia.embedding.data import PersiaBatch

    model = get_dnn_model()
    loader = make_dataloader()
    device_id = 0

    with EmbeddingCtx(
        PreprocessMode.EVAL,
        model=model,
        device_id=device_id
    ) as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=False
            )
            persia_training_batch = ctx.get_embedding_from_data(persia_batch)
            (output, labels) = ctx.forward(persia_training_batch)
Note
The examples cannot be run directly; you should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the examples get the correct result.
Note
If you set device_id=None, the training data and the model will be placed in host memory rather than in CUDA device memory by default.
- Parameters:
preprocess_mode (PreprocessMode) – the preprocess mode, which affects the behavior of prepare_features.
model (torch.nn.Module) – dense neural network PyTorch model.
embedding_config (EmbeddingConfig, optional) – the embedding configuration that will be sent to the embedding server.
- clear_embeddings()
Clear all embeddings on all embedding servers.
- configure_embedding_parameter_servers(embedding_config)
Apply an EmbeddingConfig to the embedding servers.
- Parameters:
embedding_config (EmbeddingConfig) – the embedding configuration that will be sent to the embedding server.
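For illustration, a minimal sketch of applying a configuration at runtime; it assumes EmbeddingConfig lives in persia.embedding and accepts emb_initialization, admit_probability, and weight_bound arguments (verify against your PERSIA version), and that ctx is an existing EmbeddingCtx:

    from persia.embedding import EmbeddingConfig

    # Assumed fields: emb_initialization is the uniform init range,
    # admit_probability controls feature admission, weight_bound clips weights.
    embedding_config = EmbeddingConfig(
        emb_initialization=(-0.01, 0.01),
        admit_probability=1.0,
        weight_bound=10
    )

    ctx.configure_embedding_parameter_servers(embedding_config)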
- dump_checkpoint(dst_dir, dense_filename='dense.pt', jit_dense_filename='jit_dense.pt', blocking=True, with_jit_model=False)
Save the model checkpoint (both dense and embedding) to the destination directory.
- Parameters:
dst_dir (str) – destination directory.
dense_filename (str, optional) – dense checkpoint filename.
jit_dense_filename (str, optional) – filename for the dense checkpoint saved as a PyTorch JIT script.
blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.
with_jit_model (bool, optional) – dump jit script dense checkpoint or not.
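A hedged usage sketch; the destination path is hypothetical and ctx is assumed to be an EmbeddingCtx with a model attached:

    # Dump both the dense model and the embeddings, plus a JIT-scripted copy.
    ctx.dump_checkpoint(
        "/data/checkpoints/epoch_0",  # hypothetical destination directory
        blocking=True,
        with_jit_model=True
    )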
- dump_embedding(dst_dir, blocking=True)
Dump embeddings to the destination directory. By default, this function is synchronous and waits for the embedding dump to complete before returning; internally this is done through a call to wait_for_dump_embedding. Set blocking=False to make the call asynchronous, in which case the function returns immediately and you can call wait_for_dump_embedding yourself to wait until the dump is finished.
- Parameters:
dst_dir (str) – destination directory.
blocking (bool, optional) – dump embedding in blocking mode or not.
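A short sketch of the non-blocking pattern described above; ctx is assumed to be an EmbeddingCtx and the path is hypothetical:

    # Kick off the dump without blocking the training loop...
    ctx.dump_embedding("/data/embeddings", blocking=False)
    # ...do other work here, then explicitly wait for the dump to finish.
    ctx.wait_for_dump_embedding()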
- dump_torch_state_dict(torch_instance, dst_dir, file_name, is_jit=False)
Dump a PyTorch model's or optimizer's state dict to the destination directory.
- Parameters:
torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to be dumped.
dst_dir (str) – destination directory.
file_name (str) – destination filename.
is_jit (bool, optional) – whether to dump model as jit script.
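For example, a hedged sketch that saves a dense model and its optimizer state side by side (the directory and filenames are hypothetical):

    # Save the dense model and the optimizer state as separate files.
    ctx.dump_torch_state_dict(model, "/data/checkpoints", "dense.pt")
    ctx.dump_torch_state_dict(dense_optimizer, "/data/checkpoints", "opt.pt")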
- forward(batch)
Call prepare_features and then run a forward step of the model in context.
- Parameters:
batch (PersiaTrainingBatch) – training data provided by PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings, and meta info.
- Returns:
the tuple of output data and target data.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- get_embedding_from_bytes(data, device_id=None)
Get embeddings of the serialized input batch data.
- Parameters:
data (PersiaBatch) – serialized input data without embeddings.
device_id (int, optional) – the CUDA device to use for this process.
- Returns:
PersiaTrainingBatch that contains id_type_feature_embeddings.
- Return type:
persia.prelude.PersiaTrainingBatch
- get_embedding_from_data(persia_batch, device_id=None)
Get embeddings of the input batch data.
- Parameters:
persia_batch (PersiaBatch) – input data without embeddings.
device_id (int, optional) – the CUDA device to use for this process.
- Returns:
PersiaTrainingBatch that contains id_type_feature_embeddings.
- Return type:
persia.prelude.PersiaTrainingBatch
- get_embedding_size()
Get the number of ids on all embedding servers.
- Return type:
List[int]
- load_checkpoint(src_dir, map_location=None, dense_filename='dense.pt', blocking=True)
Load the dense and embedding checkpoint from the source directory.
- Parameters:
src_dir (str) – source directory.
map_location (str, optional) – load the dense checkpoint to specific device.
dense_filename (str, optional) – dense checkpoint filename.
blocking (bool, optional) – load the embedding checkpoint in blocking mode or not.
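A hedged usage sketch mirroring dump_checkpoint (the path is hypothetical):

    # Restore both the dense model and the embeddings, mapping the dense
    # checkpoint onto GPU 0.
    ctx.load_checkpoint("/data/checkpoints/epoch_0", map_location="cuda:0")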
- load_embedding(src_dir, blocking=True)
Load embeddings from src_dir. By default, this function is synchronous and waits for the embedding load to complete before returning; internally this is done through a call to wait_for_load_embedding. Set blocking=False to make the call asynchronous, in which case the function returns immediately.
- Parameters:
src_dir (str) – directory to load embeddings from.
blocking (bool, optional) – load embeddings in blocking mode or not.
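The non-blocking counterpart of the dump sketch above (the path is hypothetical):

    # Start loading embeddings without blocking...
    ctx.load_embedding("/data/embeddings", blocking=False)
    # ...then wait before training or serving resumes.
    ctx.wait_for_load_embedding()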
- load_torch_state_dict(torch_instance, src_dir, map_location=None)
Load a PyTorch state dict from the source directory and apply it to torch_instance.
- Parameters:
torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to restore.
src_dir (str) – directory to load torch state dict.
map_location (str, optional) – load the dense checkpoint to specific device.
- prepare_features(persia_training_batch)
This function converts the data in a PersiaTrainingBatch to torch.Tensor. A PersiaTrainingBatch contains non_id_type_features, id_type_feature_embeddings, and labels, but they cannot be used directly in training until the PERSIA Tensor values are converted to torch.Tensor.
- Parameters:
persia_training_batch (PersiaTrainingBatch) – training data provided by PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings, and meta info.
- Returns:
the tuple of non_id_type_feature_tensors, id_type_feature_embedding_tensors and label_tensors.
- Return type:
Tuple[List[torch.Tensor], List[torch.Tensor], Optional[List[torch.Tensor]]]
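A minimal sketch of unpacking the returned tuple (the names follow the return description above; how the tensors feed the dense model depends on your model's forward signature):

    (
        non_id_type_feature_tensors,
        id_type_feature_embedding_tensors,
        label_tensors
    ) = ctx.prepare_features(persia_training_batch)

    # The tensors can now be fed to the dense model; label_tensors is None
    # in INFERENCE mode when no target tensor is present.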
- wait_for_dump_embedding()
Wait for the embedding dump process.
- wait_for_load_embedding()
Wait for the embedding load process.
- class persia.ctx.InferCtx(embedding_worker_address_list, *args, **kwargs)
Bases:
EmbeddingCtx
Subclass of EmbeddingCtx that provides the inference functionality without a nats-server.
Example for InferCtx:

    import numpy as np

    from persia.ctx import InferCtx
    from persia.embedding.data import PersiaBatch, IDTypeFeatureWithSingleID

    device_id = 0

    id_type_feature = IDTypeFeatureWithSingleID(
        "id_type_feature",
        np.array([1, 2, 3], np.uint64)
    )
    persia_batch = PersiaBatch([id_type_feature], requires_grad=False)

    embedding_worker_address_list = [
        "localhost:8888",
        "localhost:8889",
        "localhost:8890"
    ]

    with InferCtx(embedding_worker_address_list, device_id=device_id) as infer_ctx:
        persia_training_batch = infer_ctx.get_embedding_from_bytes(
            persia_batch.to_bytes(),
        )
        (
            non_id_type_feature_tensors,
            id_type_feature_embedding_tensors,
            label_tensors
        ) = infer_ctx.prepare_features(persia_training_batch)
Note
The example cannot be run directly; you should launch the embedding-worker and embedding-parameter-server to ensure the example gets the correct result.
- Parameters:
embedding_worker_address_list (List[str]) – list of embedding worker addresses (ip:port).
- wait_for_serving()
Wait until the embedding workers are ready to serve.
- class persia.ctx.PreprocessMode
Bases:
enum.Enum
Mode of preprocessing.
Used by prepare_features to generate features of different datatypes.
When set to TRAIN, prepare_features will return a torch tensor with the requires_grad attribute set to True. When set to EVAL, prepare_features will return a torch tensor with the requires_grad attribute set to False. INFERENCE behaves almost identically to EVAL, except that INFERENCE allows EmbeddingCtx to process a PersiaTrainingBatch without a target tensor.
- EVAL = 2
- INFERENCE = 3
- TRAIN = 1
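As a brief illustration, the mode is chosen when constructing the context; a sketch reusing the model from the EmbeddingCtx example above:

    from persia.ctx import EmbeddingCtx, PreprocessMode

    # TRAIN: tensors are produced with requires_grad=True.
    # EVAL: tensors are produced with requires_grad=False.
    # INFERENCE: like EVAL, but the batch may carry no target tensor.
    with EmbeddingCtx(PreprocessMode.INFERENCE, model=model) as ctx:
        ...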
- class persia.ctx.TrainCtx(embedding_optimizer, dense_optimizer, grad_scalar_update_factor=4, backward_buffer_size=10, backward_workers_size=8, grad_update_buffer_size=60, lookup_emb_directly=True, mixed_precision=True, distributed_option=None, *args, **kwargs)
Bases:
EmbeddingCtx
Subclass of EmbeddingCtx that implements a backward function to update the embeddings.
Example for TrainCtx:

    import torch

    import persia
    from persia.ctx import TrainCtx
    from persia.data import DataLoader, StreamingDataset

    device_id = 0
    model = get_dnn_model()
    model.cuda(device_id)

    embedding_optimizer = persia.embedding.optim.SGD(lr=1e-3)
    dense_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.BCELoss(reduction="mean")

    prefetch_size = 15
    stream_dataset = StreamingDataset(prefetch_size)

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        device_id=device_id
    ) as ctx:
        dataloader = DataLoader(stream_dataset)
        for persia_training_batch in dataloader:
            output, labels = ctx.forward(persia_training_batch)
            loss = loss_fn(output, labels[0])
            scaled_loss = ctx.backward(loss)
If you want to run a PERSIA task in a distributed environment, set distributed_option to the corresponding option you want to use. Currently supported are PyTorch DDP (DistributedDataParallel, via DDPOption) and Bagua (via BaguaDistributedOption). The default is PyTorch DDP, whose default configuration is determined by get_default_distributed_option when the environment has WORLD_SIZE > 1.
You can configure the DDPOption to your specific requirements:

    import persia
    from persia.distributed import DDPOption

    backend = "nccl"
    # backend = "gloo"  # if you want to train PERSIA on a CPU cluster

    ddp_option = DDPOption(
        backend=backend,
        init_method="tcp"
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=ddp_option
    ) as ctx:
        ...
We also integrated Bagua into PERSIA as an alternative to PyTorch DDP. Bagua is an advanced data-parallel framework, also developed by AI Platform @ Kuaishou. Using BaguaDistributedOption in place of DDPOption can significantly speed up training (see the Bagua Benchmark). For more details on the algorithms used by, and the available options of, BaguaDistributedOption, please refer to the Bagua tutorials.
Example for BaguaDistributedOption:

    from persia.distributed import BaguaDistributedOption

    algorithm = "gradient_allreduce"
    bagua_args = {}
    bagua_option = BaguaDistributedOption(
        algorithm,
        **bagua_args
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=bagua_option
    ) as ctx:
        ...
- Parameters:
embedding_optimizer (persia.embedding.optim.Optimizer) – optimizer for the embedding parameters.
dense_optimizer (torch.optim.Optimizer) – optimizer for dense parameters.
grad_scalar_update_factor (float, optional) – update factor of the GradScaler, used to ensure the loss scale stays finite when mixed_precision=True.
backward_buffer_size (int, optional) – maximum number of gradients queued in the buffer between two backward steps.
backward_workers_size (int, optional) – number of workers sending embedding gradients in parallel.
grad_update_buffer_size (int, optional) – the size of gradient buffers. The buffer will cache the gradient tensor until the embedding update is finished.
lookup_emb_directly (bool, optional) – lookup embedding directly without a separate data loader.
mixed_precision (bool) – whether to enable mixed_precision.
distributed_option (DistributedBaseOption, optional) – option for distributed training.
- backward(loss, embedding_gradient_check_frequency=20)
Update the parameters of the current dense model and embedding model.
- Parameters:
loss (torch.Tensor) – loss of current batch.
embedding_gradient_check_frequency (int, optional) – the number of batches between checks that the current embedding gradients are finite.
- Return type:
torch.Tensor
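A short usage sketch inside the training loop above, lowering the check frequency from its default of 20 (the value 10 is arbitrary):

    # Verify embedding-gradient finiteness every 10 batches instead of 20.
    scaled_loss = ctx.backward(loss, embedding_gradient_check_frequency=10)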
- dump_checkpoint(dst_dir, dense_model_filename='dense.pt', jit_dense_model_filename='jit_dense.pt', opt_filename='opt.pt', blocking=True, with_jit_model=False)
Dump the dense and embedding checkpoint to the destination directory.
- Parameters:
dst_dir (str) – destination directory.
dense_model_filename (str, optional) – dense model checkpoint filename.
jit_dense_model_filename (str, optional) – filename for the dense checkpoint saved as a PyTorch JIT script.
opt_filename (str, optional) – optimizer checkpoint filename.
blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.
with_jit_model (bool, optional) – dump dense checkpoint as jit script or not.
- load_checkpoint(src_dir, map_location=None, dense_model_filename='dense.pt', opt_filename='opt.pt', blocking=True)
Load the dense and embedding checkpoint from the source directory.
- Parameters:
src_dir (str) – source directory.
map_location (str, optional) – load the dense checkpoint to specific device.
dense_model_filename (str, optional) – dense checkpoint filename.
opt_filename (str, optional) – optimizer checkpoint filename.
blocking (bool, optional) – load the embedding checkpoint in blocking mode or not.
- wait_servers_ready()
Wait until embedding servers are ready to serve.
- persia.ctx.eval_ctx(*args, **kwargs)
Get an EmbeddingCtx in EVAL mode.
- Return type:
EmbeddingCtx
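For instance, a minimal sketch of an offline evaluation pass, assuming the model and persia_batch from the EmbeddingCtx example above:

    from persia.ctx import eval_ctx

    # eval_ctx forwards its arguments to EmbeddingCtx with the EVAL mode.
    with eval_ctx(model=model) as ctx:
        persia_training_batch = ctx.get_embedding_from_data(persia_batch)
        (output, labels) = ctx.forward(persia_training_batch)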