persia.ctx
Module Contents
- class persia.ctx.BaseCtx(threadpool_worker_size=10, device_id=None)
Initializes a common context for the other PERSIA contexts, e.g. DataCtx, EmbeddingCtx and TrainCtx. This class should not be instantiated directly.
- Parameters:
threadpool_worker_size (int) – RPC thread pool worker size.
device_id (int, optional) – the CUDA device to use for this process.
- class persia.ctx.DataCtx(*args, **kwargs)
Bases: BaseCtx
Data context that provides the communication functionality for the data generator component. Used for sending a PersiaBatch to the nn worker and the embedding worker.
If you use the DataCtx to send the PersiaBatch on the data-loader, you should use the StreamingDataset to receive the data on the nn-worker.
On the data-loader:
    from persia.ctx import DataCtx
    from persia.embedding.data import PersiaBatch

    loader = make_loader()

    with DataCtx() as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=True
            )
            ctx.send_data(persia_batch)
On the nn-worker:
    from persia.ctx import TrainCtx
    from persia.data import StreamingDataset, DataLoader

    buffer_size = 15
    streaming_dataset = StreamingDataset(buffer_size)
    data_loader = DataLoader(streaming_dataset)

    with TrainCtx(...):
        for persia_training_batch in data_loader:
            ...
Note
The examples cannot be run directly; you should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the examples get the correct results.
- Parameters:
threadpool_worker_size (int) – RPC thread pool worker size.
device_id (int, optional) – the CUDA device to use for this process.
- send_data(persia_batch)
Send a PersiaBatch from the data loader to the nn worker and embedding worker side.
- Parameters:
persia_batch (PersiaBatch) – the PersiaBatch that has not been processed yet.
- class persia.ctx.EmbeddingCtx(preprocess_mode, model=None, embedding_config=None, *args, **kwargs)
Bases: BaseCtx
Provides the embedding-related functionality. EmbeddingCtx can run offline tests or online inference depending on the preprocess_mode. The simplest way to get this context is by using eval_ctx to get an EmbeddingCtx instance.
Example for EmbeddingCtx:

    from persia.ctx import EmbeddingCtx, PreprocessMode
    from persia.embedding.data import PersiaBatch

    model = get_dnn_model()
    loader = make_dataloader()
    device_id = 0

    with EmbeddingCtx(
        PreprocessMode.EVAL,
        model=model,
        device_id=device_id
    ) as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=False
            )
            persia_training_batch = ctx.get_embedding_from_data(persia_batch)
            (output, label) = ctx.forward(persia_training_batch)
Note
The examples cannot be run directly; you should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the examples get the correct results.
Note
If you set device_id=None, the training data and the model will be placed in host memory rather than in CUDA device memory by default.
- Parameters:
preprocess_mode (PreprocessMode) – different preprocess modes affect the behavior of prepare_features.
model (torch.nn.Module) – dense neural network PyTorch model.
embedding_config (EmbeddingConfig, optional) – the embedding configuration that will be sent to the embedding server.
- clear_embeddings()
Clear all embeddings on all embedding servers.
- configure_embedding_parameter_servers(embedding_config)
Apply an EmbeddingConfig to the embedding servers.
- Parameters:
embedding_config (EmbeddingConfig) – the embedding configuration that will be sent to the embedding server.
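A minimal sketch of configure_embedding_parameter_servers, assuming EmbeddingConfig can be imported from persia.embedding and that ctx is an EmbeddingCtx entered as in the class example above:

    from persia.embedding import EmbeddingConfig

    # Push an embedding configuration to every embedding parameter server
    # before lookups start; a default EmbeddingConfig() is used here only
    # for illustration.
    embedding_config = EmbeddingConfig()
    ctx.configure_embedding_parameter_servers(embedding_config)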
- dump_checkpoint(dst_dir, dense_filename='dense.pt', jit_dense_filename='jit_dense.pt', blocking=True, with_jit_model=False)
Save the model checkpoint (both dense and embedding) to the destination directory.
- Parameters:
dst_dir (str) – destination directory.
dense_filename (str, optional) – dense checkpoint filename.
jit_dense_filename (str, optional) – filename for the dense checkpoint saved as a TorchScript (jit) module.
blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.
with_jit_model (bool, optional) – dump jit script dense checkpoint or not.
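A minimal usage sketch of dump_checkpoint, assuming ctx is an EmbeddingCtx entered as in the class example above; the destination path is hypothetical:

    # Dump the dense model and all embeddings, blocking until the embedding
    # dump on the servers has finished; also store a TorchScript copy of the
    # dense model.
    ctx.dump_checkpoint("/your/checkpoint/dir", with_jit_model=True)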
- dump_embedding(dst_dir, blocking=True)
Dump embeddings to the destination directory. By default, this function is synchronous and will wait for the embedding dump to complete before returning. This is done internally through a call to wait_for_dump_embedding. Set blocking=False to allow asynchronous computation, in which case the function will return immediately; call wait_for_dump_embedding later to wait until the dump has finished.
- Parameters:
dst_dir (str) – destination directory.
blocking (bool, optional) – whether to dump embeddings in blocking mode.
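A minimal sketch of the non-blocking path of dump_embedding, assuming ctx is an EmbeddingCtx entered as in the class example above; the destination path is hypothetical:

    # Start the embedding dump without waiting for it to complete ...
    ctx.dump_embedding("/your/embedding/dir", blocking=False)

    # ... do other work here, then explicitly wait for the dump to finish.
    ctx.wait_for_dump_embedding()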
- dump_torch_state_dict(torch_instance, dst_dir, file_name, is_jit=False)
Dump a PyTorch model or optimizer's state dict to the destination directory.
- Parameters:
torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to be dumped.
dst_dir (str) – destination directory.
file_name (str) – destination filename.
is_jit (bool, optional) – whether to dump the model as a jit script module.
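A minimal sketch of dump_torch_state_dict, assuming ctx, model and dense_optimizer are set up as in the TrainCtx example below; the paths and filenames are hypothetical:

    # Dump the dense model's state dict ...
    ctx.dump_torch_state_dict(model, "/your/checkpoint/dir", "dense.pt")

    # ... and the optimizer's state dict alongside it.
    ctx.dump_torch_state_dict(dense_optimizer, "/your/checkpoint/dir", "opt.pt")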
- forward(batch)
Call prepare_features and then do a forward step of the model in context.
- Parameters:
batch (PersiaTrainingBatch) – training data provided by the PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings and meta info.
- Returns:
the tuple of output data and target data.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- get_embedding_from_bytes(data, device_id=None)
Get embeddings of the serialized input batch data.
- Parameters:
data (PersiaBatch) – serialized input data without embeddings.
device_id (int, optional) – the CUDA device to use for this process.
- Returns:
PersiaTrainingBatch that contains id_type_feature_embeddings.
- Return type:
persia.prelude.PersiaTrainingBatch
- get_embedding_from_data(persia_batch, device_id=None)
Get embeddings of the input batch data.
- Parameters:
persia_batch (PersiaBatch) – input data without embeddings.
device_id (int, optional) – the CUDA device to use for this process.
- Returns:
PersiaTrainingBatch that contains id_type_feature_embeddings.
- Return type:
persia.prelude.PersiaTrainingBatch
- get_embedding_size()
Get the number of ids on all embedding servers.
- Return type:
List[int]
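A minimal sketch of get_embedding_size, assuming ctx is an EmbeddingCtx entered as in the class example above:

    # One entry per embedding server, each reporting the number of ids it holds.
    ids_per_server = ctx.get_embedding_size()
    print(f"total ids: {sum(ids_per_server)}")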
- load_checkpoint(src_dir, map_location=None, dense_filename='dense.pt', blocking=True)
Load the dense and embedding checkpoint from the source directory.
- Parameters:
src_dir (str) – source directory.
map_location (str, optional) – the device onto which to load the dense checkpoint.
dense_filename (str, optional) – dense checkpoint filename.
blocking (bool, optional) – load the embedding checkpoint in blocking mode or not.
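A minimal sketch of load_checkpoint, assuming ctx is an EmbeddingCtx entered as in the class example above and that the hypothetical directory was produced by dump_checkpoint:

    # Restore the dense model onto CPU and block until the embedding
    # checkpoint has been loaded by the embedding servers.
    ctx.load_checkpoint("/your/checkpoint/dir", map_location="cpu")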
- load_embedding(src_dir, blocking=True)
Load embeddings from
src_dir. By default, this function is synchronous and will wait for the embedding load to complete before returning. This is done internally through a call to wait_for_load_embedding. Set blocking=False to allow asynchronous computation, in which case the function will return immediately.
- Parameters:
src_dir (str) – directory to load embeddings.
blocking (bool, optional) – load embeddings in blocking mode or not.
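A minimal sketch of the non-blocking path of load_embedding, mirroring dump_embedding above; the source directory is hypothetical:

    # Start loading embeddings without waiting ...
    ctx.load_embedding("/your/embedding/dir", blocking=False)

    # ... then wait explicitly before training or serving continues.
    ctx.wait_for_load_embedding()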
- load_torch_state_dict(torch_instance, src_dir, map_location=None)
Load a PyTorch state dict from the source directory and apply it to torch_instance.
- Parameters:
torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to restore.
src_dir (str) – directory to load torch state dict.
map_location (str, optional) – the device onto which to load the state dict.
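A minimal sketch of load_torch_state_dict, assuming the state dicts were previously written with dump_torch_state_dict as in the sketch above; the directory is hypothetical:

    # Restore the dense model's parameters onto CPU ...
    ctx.load_torch_state_dict(model, "/your/checkpoint/dir", map_location="cpu")

    # ... and the optimizer state from the same directory.
    ctx.load_torch_state_dict(dense_optimizer, "/your/checkpoint/dir")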
- prepare_features(persia_training_batch)
This function converts the data in a PersiaTrainingBatch to torch.Tensor. A PersiaTrainingBatch contains non_id_type_features, id_type_feature_embeddings and labels, but they cannot be used directly in training until each Tensor is converted to torch.Tensor.
- Parameters:
persia_training_batch (PersiaTrainingBatch) – training data provided by PERSIA upstream including non_id_type_features, labels, id_type_feature_embeddings and meta info.
- Returns:
the tuple of non_id_type_feature_tensors, id_type_feature_embedding_tensors and label_tensors.
- Return type:
Tuple[List[torch.Tensor], List[torch.Tensor], Optional[List[torch.Tensor]]]
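A minimal sketch of unpacking the return value of prepare_features, assuming ctx is an EmbeddingCtx and persia_training_batch was obtained from get_embedding_from_data as in the class example above:

    (
        non_id_type_feature_tensors,
        id_type_feature_embedding_tensors,
        label_tensors,
    ) = ctx.prepare_features(persia_training_batch)

    # label_tensors may be None when no target is present
    # (e.g. in INFERENCE mode).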
- wait_for_dump_embedding()
Wait for the embedding dump process.
- wait_for_load_embedding()
Wait for the embedding load process.
- class persia.ctx.InferCtx(embedding_worker_address_list, *args, **kwargs)
Bases: EmbeddingCtx
Subclass of EmbeddingCtx that provides the inference functionality without nats-server.
Example for InferCtx:

    import numpy as np

    from persia.ctx import InferCtx
    from persia.embedding.data import PersiaBatch, IDTypeFeatureWithSingleID

    device_id = 0
    id_type_feature = IDTypeFeatureWithSingleID(
        "id_type_feature",
        np.array([1, 2, 3], np.uint64)
    )
    persia_batch = PersiaBatch([id_type_feature], requires_grad=False)

    embedding_worker_address_list = [
        "localhost:8888",
        "localhost:8889",
        "localhost:8890"
    ]

    with InferCtx(embedding_worker_address_list, device_id=device_id) as infer_ctx:
        persia_training_batch = infer_ctx.get_embedding_from_bytes(
            persia_batch.to_bytes(),
        )
        (
            non_id_type_feature_tensors,
            id_type_feature_embedding_tensors,
            label_tensors
        ) = infer_ctx.prepare_features(persia_training_batch)
Note
The example cannot be run directly; you should launch the embedding-worker and embedding-parameter-server to ensure the example gets the correct result.
- Parameters:
embedding_worker_address_list (List[str]) – embedding worker address (ip:port) list.
- wait_for_serving()
Wait until the embedding workers are ready to serve.
- class persia.ctx.PreprocessMode
Bases: enum.Enum
Mode of preprocessing.
Used by prepare_features to generate features of different datatypes.
When set to TRAIN, prepare_features will return a torch tensor with the requires_grad attribute set to True. When set to EVAL, prepare_features will return a torch tensor with the requires_grad attribute set to False. INFERENCE behaves almost identically to EVAL, except that INFERENCE allows EmbeddingCtx to process a PersiaTrainingBatch without a target tensor.
- EVAL = 2
- INFERENCE = 3
- TRAIN = 1
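A minimal sketch of how the mode is typically passed to an EmbeddingCtx; get_dnn_model is a placeholder for your own dense model, as in the examples above:

    from persia.ctx import EmbeddingCtx, PreprocessMode

    # EVAL: prepare_features returns tensors with requires_grad=False
    # and expects labels in the PersiaTrainingBatch.
    with EmbeddingCtx(PreprocessMode.EVAL, model=get_dnn_model()) as ctx:
        ...

    # INFERENCE: like EVAL, but the batch may omit the target tensor.
    with EmbeddingCtx(PreprocessMode.INFERENCE, model=get_dnn_model()) as ctx:
        ...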
- class persia.ctx.TrainCtx(embedding_optimizer, dense_optimizer, grad_scalar_update_factor=4, backward_buffer_size=10, backward_workers_size=8, grad_update_buffer_size=60, lookup_emb_directly=True, mixed_precision=True, distributed_option=None, *args, **kwargs)
Bases: EmbeddingCtx
Subclass of EmbeddingCtx that implements a backward function to update the embeddings.
Example for TrainCtx:

    import torch
    import persia

    from persia.ctx import TrainCtx
    from persia.data import DataLoader, StreamingDataset

    device_id = 0
    model = get_dnn_model()
    model.cuda(device_id)

    embedding_optimizer = persia.embedding.optim.SGD(lr=1e-3)
    dense_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.BCELoss(reduction="mean")

    prefetch_size = 15
    stream_dataset = StreamingDataset(prefetch_size)

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        device_id=device_id
    ) as ctx:
        dataloader = DataLoader(stream_dataset)
        for persia_training_batch in dataloader:
            output, labels = ctx.forward(persia_training_batch)
            loss = loss_fn(output, labels[0])
            scaled_loss = ctx.backward(loss)
If you want to train the PERSIA task in a distributed environment, you can set distributed_option to the corresponding option you want to use. Currently, PyTorch DDP (distributed data parallel) (DDPOption) and Bagua (BaguaDistributedOption) are supported. The default is PyTorch DDP, whose default configuration is determined by get_default_distributed_option when the environment variable WORLD_SIZE > 1.
You can configure the DDPOption to your specific requirements:

    import persia
    from persia.distributed import DDPOption

    backend = "nccl"
    # backend = "gloo"  # if you want to train PERSIA on a CPU cluster

    ddp_option = DDPOption(
        backend=backend,
        init_method="tcp"
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=ddp_option
    ) as ctx:
        ...
We also integrated Bagua into PERSIA as an alternative to PyTorch DDP. Bagua is an advanced data-parallel framework, also developed by AI Platform @ Kuaishou. Using BaguaDistributedOption in place of DDPOption can significantly speed up training (see the Bagua Benchmark). For more details on the algorithms used by, and the available options of, BaguaDistributedOption, please refer to the Bagua tutorials.
Example for BaguaDistributedOption:

    from persia.distributed import BaguaDistributedOption

    algorithm = "gradient_allreduce"
    bagua_args = {}
    bagua_option = BaguaDistributedOption(
        algorithm,
        **bagua_args
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=bagua_option
    ) as ctx:
        ...
- Parameters:
embedding_optimizer (persia.embedding.optim.Optimizer) – optimizer for the embedding parameters.
dense_optimizer (torch.optim.Optimizer) – optimizer for dense parameters.
grad_scalar_update_factor (float, optional) – update factor of Gradscalar to ensure that the loss scale stays finite when mixed_precision=True.
backward_buffer_size (int, optional) – maximum number of gradients queued in the buffer between two backward steps.
backward_workers_size (int, optional) – number of workers sending embedding gradients in parallel.
grad_update_buffer_size (int, optional) – the size of gradient buffers. The buffer will cache the gradient tensor until the embedding update is finished.
lookup_emb_directly (bool, optional) – lookup embedding directly without a separate data loader.
mixed_precision (bool) – whether to enable mixed precision training.
distributed_option (DistributedBaseOption, optional) – option for distributed training.
- backward(loss, embedding_gradient_check_frequency=20)
Update the parameters of the current dense model and embedding model.
- Parameters:
loss (torch.Tensor) – loss of current batch.
embedding_gradient_check_frequency (int, optional) – the frequency (in number of batches) at which to check whether the gradients of the current embedding are finite.
- Return type:
torch.Tensor
- dump_checkpoint(dst_dir, dense_model_filename='dense.pt', jit_dense_model_filename='jit_dense.pt', opt_filename='opt.pt', blocking=True, with_jit_model=False)
Dump the dense and embedding checkpoint to the destination directory.
- Parameters:
dst_dir (str) – destination directory.
dense_model_filename (str, optional) – dense model checkpoint filename.
jit_dense_model_filename (str, optional) – dense checkpoint filename after PyTorch jit.
opt_filename (str, optional) – optimizer checkpoint filename.
blocking (bool, optional) – dump embedding checkpoint in blocking mode or not.
with_jit_model (bool, optional) – dump dense checkpoint as jit script or not.
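A minimal sketch of TrainCtx.dump_checkpoint, assuming ctx is a TrainCtx entered as in the example above; the checkpoint directory is hypothetical:

    # Dump the dense model, the dense optimizer state and the embeddings,
    # blocking until the embedding dump has finished.
    ctx.dump_checkpoint(
        "/your/checkpoint/dir",
        dense_model_filename="dense.pt",
        opt_filename="opt.pt",
    )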
- load_checkpoint(src_dir, map_location=None, dense_model_filename='dense.pt', opt_filename='opt.pt', blocking=True)
Load the dense and embedding checkpoint from the source directory.
- Parameters:
src_dir (str) – source directory.
map_location (str, optional) – the device onto which to load the dense checkpoint.
dense_model_filename (str, optional) – dense checkpoint filename.
opt_filename (str, optional) – optimizer checkpoint filename.
blocking (bool, optional) – load the embedding checkpoint in blocking mode or not.
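A minimal sketch of TrainCtx.load_checkpoint, mirroring dump_checkpoint above; the directory is hypothetical:

    # Restore the dense model, the dense optimizer state and the embeddings,
    # placing the dense checkpoint on the training GPU.
    ctx.load_checkpoint("/your/checkpoint/dir", map_location=f"cuda:{device_id}")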
- wait_servers_ready()
Wait until embedding servers are ready to serve.
- persia.ctx.eval_ctx(*args, **kwargs)
Get an EmbeddingCtx with the EVAL mode.
- Return type:
EmbeddingCtx
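A minimal sketch of running an offline evaluation with eval_ctx, assuming its keyword arguments are forwarded to EmbeddingCtx; get_dnn_model and make_dataloader are placeholders for your own model and data pipeline, as in the examples above:

    from persia.ctx import eval_ctx
    from persia.embedding.data import PersiaBatch

    model = get_dnn_model()
    loader = make_dataloader()

    with eval_ctx(model=model) as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=False
            )
            batch = ctx.get_embedding_from_data(persia_batch)
            output, target = ctx.forward(batch)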