persia.ctx
Module Contents
- class persia.ctx.BaseCtx(threadpool_worker_size=10, device_id=None)
- Initializes a common context for the other PERSIA contexts, e.g. DataCtx, EmbeddingCtx and TrainCtx. This class should not be instantiated directly.
- Parameters:
- threadpool_worker_size (int) – rpc threadpool worker size. 
- device_id (int, optional) – the CUDA device to use for this process. 
 
 
- class persia.ctx.DataCtx(*args, **kwargs)
- Bases: BaseCtx
- The data context provides communication functionality to the data generator component. It is used for sending a PersiaBatch to the nn-worker and the embedding-worker.
- If you use DataCtx to send the PersiaBatch on the data-loader, you should use StreamingDataset to receive the data on the nn-worker.
- On data-loader:

    from persia.ctx import DataCtx
    from persia.embedding.data import PersiaBatch

    loader = make_loader()  # user-defined iterable of feature/label tuples
    with DataCtx() as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=True,
            )
            ctx.send_data(persia_batch)

- On nn-worker:

    from persia.ctx import TrainCtx
    from persia.data import StreamingDataset, DataLoader

    buffer_size = 15
    streaming_dataset = StreamingDataset(buffer_size)
    data_loader = DataLoader(streaming_dataset)
    with TrainCtx(...):
        for persia_training_batch in data_loader:
            ...

- Note: The examples cannot be run directly. You should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the example gets the correct result.
- Parameters:
- threadpool_worker_size (int) – rpc threadpool worker size. 
- device_id (int, optional) – the CUDA device to use for this process. 
 
 - send_data(persia_batch)
- Send a PersiaBatch from the data loader to the nn-worker and the embedding-worker.
- Parameters:
- persia_batch (PersiaBatch) – the PersiaBatch that has not been processed yet.
 
 
- class persia.ctx.EmbeddingCtx(preprocess_mode, model=None, embedding_config=None, *args, **kwargs)
- Bases: BaseCtx
- Provides the embedding-related functionality. EmbeddingCtx can run an offline test or online inference, depending on the preprocess_mode. The simplest way to get an EmbeddingCtx instance is to call eval_ctx.
- Example for EmbeddingCtx:

    from persia.ctx import EmbeddingCtx, PreprocessMode
    from persia.embedding.data import PersiaBatch

    model = get_dnn_model()
    loader = make_dataloader()
    device_id = 0

    with EmbeddingCtx(
        PreprocessMode.EVAL,
        model=model,
        device_id=device_id
    ) as ctx:
        for (non_id_type_features, id_type_features, labels) in loader:
            persia_batch = PersiaBatch(
                id_type_features,
                non_id_type_features=non_id_type_features,
                labels=labels,
                requires_grad=False,
            )
            persia_training_batch = ctx.get_embedding_from_data(persia_batch)
            (output, label) = ctx.forward(persia_training_batch)

- Note: The examples cannot be run directly. You should launch the nn-worker, embedding-worker, embedding-parameter-server, and nats-server to ensure the example gets the correct result.
- Note: If you set device_id=None, the training data and the model will be placed in host memory rather than in CUDA device memory by default.
- Parameters:
- preprocess_mode (PreprocessMode) – the preprocess mode, which affects the behavior of prepare_features.
- model (torch.nn.Module) – dense neural network PyTorch model. 
- embedding_config (EmbeddingConfig, optional) – the embedding configuration that will be sent to the embedding server. 
 
 - clear_embeddings()
- Clear all embeddings on all embedding servers. 
 - configure_embedding_parameter_servers(embedding_config)
- Apply EmbeddingConfig to the embedding servers.
- Parameters:
- embedding_config (EmbeddingConfig) – the embedding configuration that will be sent to the embedding server. 
 
 - dump_checkpoint(dst_dir, dense_filename='dense.pt', jit_dense_filename='jit_dense.pt', blocking=True, with_jit_model=False)
- Save the model checkpoint (both dense and embedding) to the destination directory.
- Parameters:
- dst_dir (str) – destination directory. 
- dense_filename (str, optional) – dense checkpoint filename. 
- jit_dense_filename (str, optional) – dense checkpoint filename after PyTorch jit script. 
- blocking (bool, optional) – dump embedding checkpoint in blocking mode or not. 
- with_jit_model (bool, optional) – dump jit script dense checkpoint or not. 
 
 
 - dump_embedding(dst_dir, blocking=True)
- Dump embeddings to the destination directory. By default, this function is synchronous and will wait for the completion of embedding dumping before returning. This is done internally through a call to wait_for_dump_embedding. Set blocking=False to allow asynchronous computation, in which case the function will return immediately. Call wait_for_dump_embedding to wait until the dump has finished if blocking=False.
- Parameters:
- dst_dir (str) – destination directory. 
- blocking (bool, optional) – dump embedding in blocking mode or not. 
 
 
 - dump_torch_state_dict(torch_instance, dst_dir, file_name, is_jit=False)
- Dump a PyTorch model or optimizer’s state dict to the destination directory.
- Parameters:
- torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to be dumped. 
- dst_dir (str) – destination directory. 
- file_name (str) – destination filename. 
- is_jit (bool, optional) – whether to dump model as jit script. 
 
 
 - forward(batch)
- Call prepare_features and then run a forward step of the model in this context.
- Parameters:
- batch (PersiaTrainingBatch) – training data provided by PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings and meta info. 
- Returns:
- the tuple of output data and target data. 
- Return type:
- Tuple[torch.Tensor, Optional[torch.Tensor]] 
 
 - get_embedding_from_bytes(data, device_id=None)
- Get embeddings of the serialized input batch data.
- Parameters:
- data (bytes) – serialized PersiaBatch data without embeddings, e.g. the result of PersiaBatch.to_bytes(). 
- device_id (int, optional) – the CUDA device to use for this process. 
 
- Returns:
- PersiaTrainingBatch that contains id_type_feature_embeddings. 
- Return type:
- persia.prelude.PersiaTrainingBatch 
 
 - get_embedding_from_data(persia_batch, device_id=None)
- Get embeddings of the input batch data.
- Parameters:
- persia_batch (PersiaBatch) – input data without embeddings. 
- device_id (int, optional) – the CUDA device to use for this process. 
 
- Returns:
- PersiaTrainingBatch that contains id_type_feature_embeddings. 
- Return type:
- persia.prelude.PersiaTrainingBatch 
 
 - get_embedding_size()
- Get the number of ids on all embedding servers.
- Return type:
- List[int] 
 
 - load_checkpoint(src_dir, map_location=None, dense_filename='dense.pt', blocking=True)
- Load the dense and embedding checkpoint from the source directory.
- Parameters:
- src_dir (str) – source directory. 
- map_location (str, optional) – load the dense checkpoint to specific device. 
- dense_filename (str, optional) – dense checkpoint filename. 
- blocking (bool, optional) – load embedding checkpoint in blocking mode or not. 
 
 
 - load_embedding(src_dir, blocking=True)
- Load embeddings from src_dir. By default, this function is synchronous and will wait for the completion of embedding loading before returning. This is done internally through a call to wait_for_load_embedding. Set blocking=False to allow asynchronous computation, in which case the function will return immediately.
- Parameters:
- src_dir (str) – directory to load embeddings. 
- blocking (bool, optional) – load embedding in blocking mode or not. 
 
 
 - load_torch_state_dict(torch_instance, src_dir, map_location=None)
- Load a PyTorch state dict from the source directory and apply it to torch_instance.
- Parameters:
- torch_instance (torch.nn.Module or torch.optim.Optimizer) – dense model or optimizer to restore. 
- src_dir (str) – directory to load torch state dict. 
- map_location (str, optional) – load the dense checkpoint to specific device. 
 
 
 - prepare_features(persia_training_batch)
- This function converts data from PersiaTrainingBatch to torch.Tensor. PersiaTrainingBatch contains non_id_type_features, id_type_feature_embeddings and labels, but they cannot be used directly in training until each Tensor is converted to a torch.Tensor.
- Parameters:
- persia_training_batch (PersiaTrainingBatch) – training data provided by PERSIA upstream, including non_id_type_features, labels, id_type_feature_embeddings and meta info. 
- Returns:
- the tuple of non_id_type_feature_tensors, id_type_feature_embedding_tensors and label_tensors. 
- Return type:
- Tuple[List[torch.Tensor], List[torch.Tensor], Optional[List[torch.Tensor]]] 
 
 - wait_for_dump_embedding()
- Wait for the embedding dump process. 
 - wait_for_load_embedding()
- Wait for the embedding load process. 
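The blocking and non-blocking modes of dump_embedding and load_embedding described above follow a common "start, then wait" pattern. The sketch below imitates that pattern with a hypothetical stand-in class (FakeEmbeddingCtx) built on a plain thread. It does not use PERSIA and says nothing about how PERSIA implements the calls; it only shows where the explicit wait belongs when blocking=False.

```python
import threading
import time


class FakeEmbeddingCtx:
    """Hypothetical stand-in that imitates the documented calling pattern only."""

    def __init__(self):
        self._worker = None
        self.dumped = []

    def _do_dump(self, dst_dir):
        time.sleep(0.05)  # pretend the embedding servers are busy dumping
        self.dumped.append(dst_dir)

    def dump_embedding(self, dst_dir, blocking=True):
        self._worker = threading.Thread(target=self._do_dump, args=(dst_dir,))
        self._worker.start()
        if blocking:
            # Blocking mode waits internally before returning.
            self.wait_for_dump_embedding()

    def wait_for_dump_embedding(self):
        if self._worker is not None:
            self._worker.join()
            self._worker = None


ctx = FakeEmbeddingCtx()

ctx.dump_embedding("ckpt_a")                  # returns once the dump is done
done_after_blocking = list(ctx.dumped)

ctx.dump_embedding("ckpt_b", blocking=False)  # returns immediately
# ... other work can overlap with the dump here ...
ctx.wait_for_dump_embedding()                 # explicit synchronization point
done_after_wait = list(ctx.dumped)
```

With a real context the same shape applies: issue the non-blocking call, overlap other work, and call the matching wait_for_* method before relying on the result.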
 
- class persia.ctx.InferCtx(embedding_worker_address_list, *args, **kwargs)
- Bases: EmbeddingCtx
- Subclass of EmbeddingCtx that provides the inference functionality without nats-servers.
- Example for InferCtx:

    import numpy as np
    from persia.ctx import InferCtx
    from persia.embedding.data import PersiaBatch, IDTypeFeatureWithSingleID

    device_id = 0
    id_type_feature = IDTypeFeatureWithSingleID(
        "id_type_feature",
        np.array([1, 2, 3], np.uint64)
    )
    persia_batch = PersiaBatch([id_type_feature], requires_grad=False)

    embedding_worker_address_list = [
        "localhost:8888",
        "localhost:8889",
        "localhost:8890"
    ]
    with InferCtx(embedding_worker_address_list, device_id=device_id) as infer_ctx:
        persia_training_batch = infer_ctx.get_embedding_from_bytes(
            persia_batch.to_bytes(),
        )
        (
            non_id_type_feature_tensors,
            id_type_feature_embedding_tensors,
            label_tensors
        ) = infer_ctx.prepare_features(persia_training_batch)

- Note: The example cannot be run directly. You should launch the embedding-worker and embedding-parameter-server to ensure the example gets the correct result.
- Parameters:
- embedding_worker_address_list (List[str]) – embedding worker address (ip:port) list. 
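Since the embedding worker addresses are documented as ip:port strings, it can help to validate the list before constructing the context. The helper below, normalize_addresses, is hypothetical and not part of PERSIA. It is a minimal sketch that strips stray whitespace and rejects entries without a numeric port, and it does not handle IPv6 literals.

```python
def normalize_addresses(addresses):
    """Return cleaned "host:port" strings; hypothetical helper, not a PERSIA API."""
    cleaned = []
    for address in addresses:
        # Split on the last colon so the port is always the final component.
        host, sep, port = address.strip().rpartition(":")
        host, port = host.strip(), port.strip()
        if not sep or not host or not port.isdigit():
            raise ValueError(f"expected 'ip:port', got {address!r}")
        cleaned.append(f"{host}:{int(port)}")
    return cleaned


addresses = normalize_addresses(["localhost: 8888", "10.0.0.1:8889"])
```

The cleaned list can then be passed wherever a list of ip:port strings is expected.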
 
 - wait_for_serving()
 
- class persia.ctx.PreprocessMode
- Bases: enum.Enum
- Mode of preprocessing. Used by prepare_features to generate features of different datatypes.
- When set to TRAIN, prepare_features will return a torch tensor with the requires_grad attribute set to True. When set to EVAL, prepare_features will return a torch tensor with the requires_grad attribute set to False. INFERENCE behaves almost identically to EVAL, except that INFERENCE allows EmbeddingCtx to process the PersiaTrainingBatch without a target tensor.
 - EVAL = 2
 - INFERENCE = 3
 - TRAIN = 1
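The three modes can be summarized as two flags: whether the returned tensors require gradients, and whether a target tensor is expected. The sketch below re-declares the enum locally with the documented values so that it runs without PERSIA. The describe helper is hypothetical and only restates the description above.

```python
from enum import Enum


class PreprocessMode(Enum):
    """Local stand-in mirroring the documented members; not imported from persia."""
    TRAIN = 1
    EVAL = 2
    INFERENCE = 3


def describe(mode):
    """Return (requires_grad, target_expected) as described in the text."""
    requires_grad = mode is PreprocessMode.TRAIN
    target_expected = mode is not PreprocessMode.INFERENCE
    return requires_grad, target_expected


summary = {mode.name: describe(mode) for mode in PreprocessMode}
```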
 
- class persia.ctx.TrainCtx(embedding_optimizer, dense_optimizer, grad_scalar_update_factor=4, backward_buffer_size=10, backward_workers_size=8, grad_update_buffer_size=60, lookup_emb_directly=True, mixed_precision=True, distributed_option=None, *args, **kwargs)
- Bases: EmbeddingCtx
- Subclass of EmbeddingCtx that implements a backward function to update the embeddings.
- Example for TrainCtx:

    import torch
    import persia
    from persia.ctx import TrainCtx
    from persia.data import DataLoader, StreamingDataset

    device_id = 0
    model = get_dnn_model()
    model.cuda(device_id)

    embedding_optimizer = persia.embedding.optim.SGD(lr=1e-3)
    dense_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.BCELoss(reduction="mean")

    prefetch_size = 15
    stream_dataset = StreamingDataset(prefetch_size)

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        device_id=device_id
    ) as ctx:
        dataloader = DataLoader(stream_dataset)
        for persia_training_batch in dataloader:
            output, labels = ctx.forward(persia_training_batch)
            loss = loss_fn(output, labels[0])
            scaled_loss = ctx.backward(loss)

- If you want to train the PERSIA task in a distributed environment, set distributed_option to the option you want to use. PyTorch DDP (distributed data-parallel, DDPOption) and Bagua (BaguaDistributedOption) are currently supported. The default is PyTorch DDP. The default configuration is determined by get_default_distributed_option when the environment variable WORLD_SIZE > 1.
- You can configure the DDPOption to your specific requirements.

    from persia.ctx import TrainCtx
    from persia.distributed import DDPOption

    backend = "nccl"
    # backend = "gloo"  # If you want to train PERSIA on a CPU cluster.
    ddp_option = DDPOption(
        backend=backend,
        init_method="tcp"
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=ddp_option
    ) as ctx:
        ...

- We also integrated Bagua into PERSIA as an alternative to PyTorch DDP. Bagua is an advanced data-parallel framework, also developed by AI Platform @ Kuaishou. Using BaguaDistributedOption in place of DDPOption can significantly speed up the training (see Bagua Benchmark). For more details on the algorithms used by and the available options of BaguaDistributedOption, please refer to the Bagua tutorials.
- Example for BaguaDistributedOption:

    from persia.ctx import TrainCtx
    from persia.distributed import BaguaDistributedOption

    algorithm = "gradient_allreduce"
    bagua_args = {}
    bagua_option = BaguaDistributedOption(
        algorithm,
        **bagua_args
    )

    with TrainCtx(
        embedding_optimizer,
        dense_optimizer,
        model=model,
        distributed_option=bagua_option
    ) as ctx:
        ...

- Parameters:
- embedding_optimizer (persia.embedding.optim.Optimizer) – optimizer for the embedding parameters. 
- dense_optimizer (torch.optim.Optimizer) – optimizer for dense parameters. 
- grad_scalar_update_factor (float, optional) – update factor of Gradscalar, used to ensure that the loss scale is finite when mixed_precision=True.
- backward_buffer_size (int, optional) – maximum number of gradients queued in the buffer between two backward steps. 
- backward_workers_size (int, optional) – number of workers sending embedding gradients in parallel. 
- grad_update_buffer_size (int, optional) – the size of gradient buffers. The buffer will cache the gradient tensor until the embedding update is finished. 
- lookup_emb_directly (bool, optional) – lookup embedding directly without a separate data loader. 
- mixed_precision (bool) – whether to enable mixed_precision. 
- distributed_option (DistributedBaseOption, optional) – option for distributed training. 
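The description above ties the default distributed configuration to the WORLD_SIZE environment variable being greater than 1. The check itself can be sketched with the standard library as below. The function is_distributed is a hypothetical helper, not a PERSIA API, and treating a missing variable as a world size of 1 is an assumption of this sketch.

```python
import os


def is_distributed(environ=None):
    """True when WORLD_SIZE is set to a value greater than 1 (hypothetical helper)."""
    environ = os.environ if environ is None else environ
    # Assumption: an unset WORLD_SIZE means a single-process run.
    return int(environ.get("WORLD_SIZE", "1")) > 1


single = is_distributed({})
multi = is_distributed({"WORLD_SIZE": "4"})
```

A launcher script could use such a check to decide whether to pass an explicit distributed_option at all.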
 
 - backward(loss, embedding_gradient_check_frequency=20)
- Update the parameters of the current dense model and embedding model.
- Parameters:
- loss (torch.Tensor) – loss of current batch. 
- embedding_gradient_check_frequency (int, optional) – number of batches between checks of whether the current embedding gradients are finite. 
 
- Return type:
- torch.Tensor 
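One way to picture embedding_gradient_check_frequency is as a periodic finiteness check on the gradients. The sketch below uses plain Python floats and two hypothetical helpers (should_check, all_finite). It assumes the check runs on every N-th batch, and it is not PERSIA's implementation: what PERSIA does when a non-finite gradient is found is not stated here.

```python
import math


def should_check(step, frequency=20):
    """True on every `frequency`-th step, counting steps from 1."""
    return step % frequency == 0


def all_finite(values):
    """True when no value is inf or NaN."""
    return all(math.isfinite(v) for v in values)


flagged = []
for step in range(1, 61):
    # Toy gradients: an overflow is injected at step 40 only.
    grads = [0.1, -0.2, float("inf") if step == 40 else 0.3]
    if should_check(step) and not all_finite(grads):
        flagged.append(step)  # a trainer might skip the update or rescale here
```

A lower frequency value catches overflow sooner at the cost of more checks; a higher one does the opposite.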
 
 - dump_checkpoint(dst_dir, dense_model_filename='dense.pt', jit_dense_model_filename='jit_dense.pt', opt_filename='opt.pt', blocking=True, with_jit_model=False)
- Dump the dense and embedding checkpoint to the destination directory.
- Parameters:
- dst_dir (str) – destination directory. 
- dense_model_filename (str, optional) – dense model checkpoint filename. 
- jit_dense_model_filename (str, optional) – dense checkpoint filename after PyTorch jit. 
- opt_filename (str, optional) – optimizer checkpoint filename. 
- blocking (bool, optional) – dump embedding checkpoint in blocking mode or not. 
- with_jit_model (bool, optional) – dump dense checkpoint as jit script or not. 
 
 
 - load_checkpoint(src_dir, map_location=None, dense_model_filename='dense.pt', opt_filename='opt.pt', blocking=True)
- Load the dense and embedding checkpoint from the source directory.
- Parameters:
- src_dir (str) – source directory. 
- map_location (str, optional) – load the dense checkpoint to specific device. 
- dense_model_filename (str, optional) – dense checkpoint filename. 
- opt_filename (str, optional) – optimizer checkpoint filename. 
- blocking (bool, optional) – load embedding checkpoint in blocking mode or not. 
 
 
 - wait_servers_ready()
- Wait until embedding servers are ready to serve. 
 
- persia.ctx.eval_ctx(*args, **kwargs)
- Get the EmbeddingCtx with the EVAL mode.
- Return type: