add ps cache manager

lizhenyu 2020-11-04 15:36:39 +08:00
parent 1033166d8a
commit e3f7ae61db
53 changed files with 1327 additions and 128 deletions

View File

@ -194,6 +194,14 @@ if (ENABLE_GPU)
)
endif ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
install(
TARGETS ps_cache
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
if (ENABLE_SERVING OR ENABLE_TESTCASES)
file(GLOB_RECURSE LIBEVENT_LIB_LIST
${libevent_LIBPATH}/libevent*

View File

@ -308,7 +308,7 @@ elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
else ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
if (${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm)
endif()

View File

@ -75,6 +75,9 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs
if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) {
node_so_ = CUST_AICPU_OPS_SO_NAME;
node_name_ = kCustRunApi;
} else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) {
node_so_ = AICPU_OPS_SO_NAME;
node_name_ = kCustRunApi;
} else {
node_so_ = AICPU_OPS_SO_NAME;
}
@ -161,6 +164,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) {
node_so_ = CUST_AICPU_OPS_SO_NAME;
node_name_ = kCustRunApi;
} else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) {
node_so_ = AICPU_OPS_SO_NAME;
node_name_ = kCustRunApi;
} else {
node_so_ = AICPU_OPS_SO_NAME;
}

View File

@ -49,6 +49,7 @@ constexpr auto kIdentity = "Identity";
constexpr auto kUpdateCache = "UpdateCache";
constexpr auto kCustRunApi = "RunCpuKernel";
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity};
const std::set<std::string> kCacheKernelOps{kUpdateCache};
struct AicpuParamHead {
uint32_t length; // Total length: includes custom message

View File

@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h"
#include <vector>
#include <algorithm>
#include "ps/worker.h"
namespace mindspore {
@ -38,10 +39,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
}
std::vector<size_t> keys{key_, key_, key_};
std::vector<float> values;
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(values),
[](size_t dim) -> float { return SizeToFloat(dim); });
std::transform(indices_shape.begin(), indices_shape.end(), std::back_inserter(values),
[](size_t dim) -> float { return SizeToFloat(dim); });
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(values),
[](size_t dim) -> float { return SizeToFloat(dim); });
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()),
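Note on the packing convention above: the worker repeats the same parameter key once per logical segment, flattens all three shape vectors into a single float payload, and records each segment's element count in lens. A minimal standalone sketch of this ps-lite style packing (illustrative names, not the actual Worker API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Pack several shape vectors into one flat float payload:
// one repeated key per segment, lens[i] = element count of segment i.
struct PackedKV {
  std::vector<size_t> keys;
  std::vector<float> values;
  std::vector<int64_t> lens;
};

PackedKV PackShapes(size_t key, const std::vector<std::vector<size_t>> &segments) {
  PackedKV kv;
  for (const auto &seg : segments) {
    kv.keys.push_back(key);  // same key repeated for every segment
    kv.lens.push_back(static_cast<int64_t>(seg.size()));
    for (size_t dim : seg) {
      kv.values.push_back(static_cast<float>(dim));  // dims carried as floats
    }
  }
  return kv;
}

int main() {
  // Hypothetical input/indices/output shapes, as in InitKernel above.
  auto kv = PackShapes(7, {{8000, 80}, {32}, {32, 80}});
  std::cout << "segments: " << kv.lens.size() << ", values: " << kv.values.size() << std::endl;
  return 0;
}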

View File

@ -72,6 +72,23 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, con
return Launch(inputs, workspace, outputs);
}
void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids,
const float *update_vals, size_t ids_size) {
size_t copy_lens = outer_dim_size_ * sizeof(float);
for (size_t i = 0; i < ids_size; ++i) {
int index = lookup_ids[i] - offset_;
if (index >= 0 && index < SizeToInt(first_dim_size_)) {
auto ret =
memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
}
} else {
MS_LOG(EXCEPTION) << "UpdateEmbeddings index invalid.";
}
}
}
const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; }
const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); }

View File

@ -35,7 +35,8 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK
bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals,
size_t ids_size) override;
const std::vector<size_t> &input_sizes() const override;
const std::vector<size_t> &output_sizes() const override;
const std::vector<size_t> &workspace_sizes() const override;

View File

@ -38,7 +38,8 @@ class PServerKernel {
virtual void ReInit(const std::vector<std::vector<size_t>> &) {}
virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0;
virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals,
size_t ids_size) {}
virtual const std::vector<size_t> &input_sizes() const = 0;
virtual const std::vector<size_t> &output_sizes() const = 0;
virtual const std::vector<size_t> &workspace_sizes() const = 0;

View File

@ -56,6 +56,7 @@
#include "toolchain/adx_datadump_server.h"
#if ENABLE_CPU && ENABLE_D
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore {
@ -487,11 +488,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
// adjust kernel
AdjustKernel(root_graph);
#if ENABLE_CPU && ENABLE_D
InitPsWorker(root_graph);
#endif
// assign stream
AssignStream(NOT_NULL(root_graph));
@ -568,6 +565,9 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
}
// adjust execution order because merge child graph and other special operations
AdjustKernel(graph);
#if ENABLE_CPU && ENABLE_D
InitPsWorker(graph);
#endif
// Reorder optimizer order
auto execution_order = graph->execution_order();
Reorder(&execution_order);
@ -644,6 +644,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
#if ENABLE_CPU && ENABLE_D
// Initialize parameter server
InitPSParamAndOptim(kernel_graph, inputs);
std::string channel_name;
if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) {
ps::ps_cache_instance.IncreaseGraphStep(channel_name);
}
#endif
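The IncreaseGraphStep call above is the worker half of a handshake with the cache's data-processing thread: a cache slot may only be recycled once the graph has consumed the batch that used it. A minimal sketch of such a step-counter handshake, with assumed semantics and illustrative names:

#include <condition_variable>
#include <cstdint>
#include <mutex>

// Data thread blocks until the graph has advanced past the given step.
class StepBarrier {
 public:
  // Called once per RunGraph, mirroring IncreaseGraphStep above.
  void IncreaseGraphStep() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++graph_step_;
    cv_.notify_all();
  }
  // Called by the cache thread before evicting slots used at running_step.
  void WaitGraphRun(uint64_t running_step) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return graph_step_ > running_step; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  uint64_t graph_step_{0};
};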
{
// run task on device

View File

@ -21,6 +21,9 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore {
namespace session {

View File

@ -64,6 +64,7 @@
#include "utils/ms_context.h"
#if ENABLE_CPU && ENABLE_GPU
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore {
@ -237,6 +238,12 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
#if ENABLE_CPU && ENABLE_GPU
const std::string &param_name = input_node->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
continue;
}
#endif
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
@ -300,16 +307,11 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
HardwareOptimize(graph);
// Graph kernel fusion optimization
GraphKernelOptimize(graph);
// Start gpu kernel runtime
StartKernelRT();
#if ENABLE_CPU && ENABLE_GPU
InitPsWorker(graph);
#endif
// Assign CUDA streams
AssignStream(graph);
// Dump .pb graph before remove nop nodes
@ -374,6 +376,12 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
int kernel_num = kernel_graph->execution_order().size();
int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1;
for (int64_t i = 0; i < loopsize; i++) {
#if ENABLE_CPU && ENABLE_GPU
std::string channel_name;
if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) {
ps::ps_cache_instance.IncreaseGraphStep(channel_name);
}
#endif
Execute(kernel_graph);
}
// In pynative mode, device addresses of tensors in value nodes need be clean.

View File

@ -41,8 +41,10 @@
#include "utils/trace_base.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/worker.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/common.h"
#include "ps/util.h"
#include "abstract/abstract_value.h"
#endif
namespace mindspore {
@ -1125,6 +1127,12 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
}
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
const std::string &param_name = input_node->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
continue;
}
#endif
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size,
@ -1715,8 +1723,64 @@ void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<ten
}
}
bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) {
auto kernel_graph = graphs_[graph_id];
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &kernel_node : kernel_graph->execution_order()) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == kGetNextOpName) {
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
*channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
return true;
}
}
return false;
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::Util::IsRoleOfWorker()) {
return;
}
CheckPSModeConsistence(kernel_graph);
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
if (!ps::ps_cache_instance.initialized_ps_cache()) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_target, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto context = runtime_instance->context();
const auto &kernels = kernel_graph->execution_order();
if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
GetBatchElements(kernels[0]);
ps::ps_cache_instance.Initialize();
}
ps::ps_cache_instance.DoProcessData(device_id_, context);
}
} else {
// Assign parameter keys.
AssignParamKey(kernel_graph);
}
}
void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
<< types;
}
size_t batch_elements = 1;
const auto &shape = shapes[0];
for (size_t i = 0; i < shape.size(); ++i) {
batch_elements *= shape[i];
}
ps::ps_cache_instance.set_batch_elements(batch_elements);
}
void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const {
auto input_nodes = kernel_graph->inputs();
for (const auto &input_node : input_nodes) {
if (!input_node->isa<Parameter>()) {
@ -1725,8 +1789,9 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
auto pk_node = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
auto param_info_ptr = pk_node->param_info();
const std::string &param_name = pk_node->fullname_with_scope();
if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
!ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name
<< "] in server, this parameter is used by kernel which executes in device";
}
@ -1734,10 +1799,6 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
}
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
for (auto &node : node_list) {
@ -1775,16 +1836,8 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
return;
}
std::vector<tensor::TensorPtr> inputs(inputs_const);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
for (size_t i = 0; i < inputs.size(); ++i) {

View File

@ -99,9 +99,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
// get graph id in child graphs by ME front anf node pointer
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
void AssignParamKey(const KernelGraphPtr &kernel_graph);
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
bool IsGetNextGraph(const GraphId &graph_id, std::string *channel_name);
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
std::string *error_msg) const {
return true;
@ -195,6 +195,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const;
void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
@ -207,6 +212,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
#if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_;
#endif
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
bool initialized_ps_cache_{false};
#endif
};
using SessionPtr = std::shared_ptr<session::SessionBasic>;

View File

@ -24,6 +24,9 @@
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
namespace mindspore {
namespace parallel {
@ -514,6 +517,12 @@ Status GatherV2PInfo::InferBias() {
if (repeated_calc_num_ > 1) {
rank = rank / repeated_calc_num_;
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
bias_ = 0;
return SUCCESS;
}
#endif
bias_ = rank / params_strategy.at(1) * slice_size_;
return SUCCESS;
}
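For intuition on InferBias: with the table row-sharded along the first strategy dimension, each rank subtracts bias_ = (rank / params_strategy[1]) * slice_size_ from incoming ids; with the PS cache enabled, the kernel already receives cache-local indices, so the bias collapses to zero. A worked example with hypothetical numbers:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sharding: 8 ranks, params_strategy = {4, 2}:
  // 4 row-shards, each replicated across 2 ranks, over an 8000-row table.
  const int64_t rank = 5, col_strategy = 2, slice_size = 8000 / 4;
  const bool cache_enable = false;  // flip to true to model the PS cache path
  const int64_t bias = cache_enable ? 0 : (rank / col_strategy) * slice_size;
  std::cout << "rank " << rank << " owns rows [" << bias << ", " << bias + slice_size << ")" << std::endl;
  return 0;
}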

View File

@ -46,10 +46,18 @@
#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
namespace mindspore {
namespace parallel {
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
return false;
}
#endif
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();

View File

@ -44,6 +44,9 @@
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
using mindspore::tensor::Tensor;
@ -3036,6 +3039,11 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
}
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
return false;
}
#endif
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());

View File

@ -201,6 +201,7 @@ else ()
if (${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
endif ()
target_link_libraries(_c_dataengine PRIVATE ps_cache)
endif ()
endif ()

View File

@ -322,6 +322,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
bool profiling, int32_t *push_time) {
std::vector<device::DataItemGpu> items;
double start_time;
bool ps_data_prefetch = false;
for (int i = 0; i < data_size.size(); i++) {
device::DataItemGpu data_item;
data_item.data_len_ = data_size[i];
@ -334,6 +335,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond();
}
// Data prefetch only when PS mode enables cache.
if ((!ps_data_prefetch) && (items.size() > 0)) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_);
ps_data_prefetch = true;
}
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (profiling) {
double end_time = ProfilingTime::GetCurMilliSecond();

View File

@ -24,6 +24,7 @@
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#ifdef ENABLE_TDTQUE
#include "minddata/dataset/util/queue.h"

View File

@ -17,6 +17,8 @@
#include "utils/ms_utils.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/util/log_adapter.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
namespace mindspore {
namespace dataset {
static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
@ -48,6 +50,10 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond();
}
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
}
if (tdt::TdtHostPushData(channel_name, items) != 0) {
return FAILED;
}
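The GPU queue push above and this TDT host push install the same hook: before a batch's first item reaches the device queue, its host pointer is handed to PsDataPrefetch so the cache thread can read, and then rewrite, the batch ids. A minimal sketch of that producer/consumer rendezvous, assuming the producer blocks only while the cache thread holds the batch:

#include <condition_variable>
#include <cstddef>
#include <mutex>

// Hand one in-flight batch pointer from the data pipeline to the cache thread.
class DataPrefetchSketch {
 public:
  // Producer: publish the batch, then wait until FinalizeData releases it.
  void PrefetchData(void *data, size_t size) {
    std::unique_lock<std::mutex> lock(mutex_);
    data_ = data;
    size_ = size;
    ready_ = true;
    cv_.notify_all();
    cv_.wait(lock, [this] { return !ready_; });
  }
  // Cache thread: take the batch and remap its ids in place.
  void *WaitData(size_t *size) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return ready_; });
    *size = size_;
    return data_;
  }
  // Cache thread: done with the batch, unblock the producer.
  void FinalizeData() {
    std::lock_guard<std::mutex> lock(mutex_);
    ready_ = false;
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  void *data_{nullptr};
  size_t size_{0};
  bool ready_{false};
};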

View File

@ -308,7 +308,14 @@ PYBIND11_MODULE(_c_expression, m) {
.def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
.def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
.def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.")
.def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.")
.def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize,
"Insert hash table size with new parameter name.")
.def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())

View File

@ -52,6 +52,7 @@
#include "ps/common.h"
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
#if (ENABLE_GE || ENABLE_D)
@ -921,6 +922,11 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) {
return true;
}
#endif
MS_LOG(INFO) << "Start InitDataSet Entry";
ShapeVector int_input_indexes;
(void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
@ -966,7 +972,17 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
backend->Link(runner.graph_id);
}
// PS mode does not support loop sink.
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfWorker()) {
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
ConfigManager::GetInstance().set_iter_num(1);
} else {
#endif
ConfigManager::GetInstance().set_iter_num(size);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
}
#endif
if (!(*runner.run)) {
// empty function
@ -981,7 +997,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
}
MS_LOG(DEBUG) << "InitDataSetVm End.";
return true;
}
} // namespace pipeline
void ResetOpId() { mindspore::id_generator::reset_id(); }

View File

@ -14,22 +14,20 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif ()
if (NOT ENABLE_D)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif()
if (NOT ENABLE_GPU)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif()
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
add_subdirectory(ps_cache)
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})

View File

@ -64,6 +64,7 @@ constexpr int64_t kInitWeightToOptimIdCmd = 11;
constexpr int64_t kInitOptimInputsShapeCmd = 12;
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
constexpr int64_t kInitEmbeddingsCmd = 20;
constexpr int64_t kUpdateEmbeddingsCmd = 21;
constexpr int64_t kCheckReadyForPushCmd = 25;
constexpr int64_t kCheckReadyForPullCmd = 26;
constexpr int64_t kEmbeddingLookupCmd = 30;

View File

@ -51,6 +51,8 @@
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/random_normal/random_normal.h"
namespace mindspore {
namespace ps {
@ -100,6 +102,7 @@ class ParameterServer {
void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
ParameterServer *ps_;
@ -118,13 +121,15 @@ class ParameterServer {
void InitWeight(const Key &key, const WeightPtr &weight);
void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
const ParamInitInfo &param_init_info);
bool HasWeight(const Key &key);
void Finalize();
void UpdateWeights();
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
bool ReadyForUpdateWeights();
bool ReadyForPush(const Key &key);
bool ReadyForPull(const Key &key);
@ -193,6 +198,7 @@ void ParameterServer<T>::ServerHandler::Init() {
handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
}
@ -302,7 +308,17 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
for (int64_t k = 0; k < lens[2]; k++) {
output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
}
ParamInitInfo param_init_info;
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
param_init_info.param_type_ = static_cast<ParamType>(lens[3]);
if (param_init_info.param_type_ == kWeight) {
param_init_info.global_seed_ = static_cast<size_t>(lens[4]);
param_init_info.op_seed_ = static_cast<size_t>(lens[5]);
} else if (param_init_info.param_type_ == kAccumulation) {
param_init_info.init_val_ = req_data.vals[index];
}
}
ps_->InitEmbeddingTable(key, shapes, param_init_info);
}
template <typename T>
@ -338,6 +354,18 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
const Key &key = req_data.keys[0];
const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size());
const Values &update_vals = req_data.vals;
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
@ -476,7 +504,8 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
template <typename T>
void ParameterServer<T>::InitEmbeddingTable(
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
const ParamInitInfo &param_init_info) {
MS_EXCEPTION_IF_NULL(shapes);
if (weights_.count(key) == 0) {
std::shared_ptr<PServerKernel> lookup =
@ -493,8 +522,18 @@ void ParameterServer<T>::InitEmbeddingTable(
T *embedding_data = embedding->data();
std::default_random_engine engine;
std::normal_distribution<float> random(0, 0.01);
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
if (param_init_info.param_type_ == kWeight) {
InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data);
} else if (param_init_info.param_type_ == kAccumulation) {
for (size_t i = 0; i < total_dims; i++) {
embedding_data[i] = param_init_info.init_val_;
}
}
} else {
for (size_t i = 0; i < total_dims; i++) {
embedding_data[i] = random(engine);
}
}
weights_[key] = embedding;
tokens_[key] = 0;
@ -673,6 +712,23 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
res->lens.push_back(res->vals.size());
}
template <typename T>
void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
if (weights_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid embedding table key " << key;
return;
}
if (embedding_lookup_ops_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
return;
}
WeightPtr table_ptr = weights_[key];
MS_EXCEPTION_IF_NULL(table_ptr);
std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
MS_EXCEPTION_IF_NULL(table_lookup_op);
table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}
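On the wire, the update request reuses the lookup layout: keys = {table_key, id_0, ..., id_{n-1}} and vals carries the n replacement rows back to back, embedding_size floats each. A hedged sketch of packing such a request on the worker side (illustrative, not the actual Worker API):

#include <cstddef>
#include <cstdint>
#include <vector>

struct UpdateRequest {
  std::vector<uint64_t> keys;  // keys[0] = table key, keys[1..] = lookup ids
  std::vector<float> vals;     // ids.size() * embedding_size floats, row-major
};

UpdateRequest PackUpdate(uint64_t table_key, const std::vector<uint64_t> &ids,
                         const std::vector<float> &rows, size_t embedding_size) {
  UpdateRequest req;
  req.keys.reserve(1 + ids.size());
  req.keys.push_back(table_key);
  req.keys.insert(req.keys.end(), ids.begin(), ids.end());
  // Caller guarantees rows.size() == ids.size() * embedding_size.
  (void)embedding_size;
  req.vals = rows;
  return req;
}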
template <typename T>
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();

View File

@ -70,9 +70,16 @@ void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
}
auto &hash_table_info = iter->second;
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return;
}
hash_table_info.param_init_info_.param_type_ = kWeight;
hash_table_info.param_init_info_.global_seed_ = global_seed;
hash_table_info.param_init_info_.op_seed_ = op_seed;
if (CheckFinishInsertInitInfo()) {
finish_insert_init_info_ = true;
insert_init_info_.notify_one();
}
}
void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float init_val) {
@ -81,8 +88,26 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float i
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
}
auto &hash_table_info = iter->second;
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return;
}
hash_table_info.param_init_info_.param_type_ = kAccumulation;
hash_table_info.param_init_info_.init_val_ = init_val;
if (CheckFinishInsertInitInfo()) {
finish_insert_init_info_ = true;
insert_init_info_.notify_one();
}
}
bool PsCacheManager::CheckFinishInsertInitInfo() const {
for (const auto &item : hash_tables_) {
const auto &hash_table_info = item.second;
const auto &param_init_info = hash_table_info.param_init_info_;
if (param_init_info.param_type_ == kUnKnown) {
return false;
}
}
return true;
}
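InitParameterServer (further down) blocks until every hash table has received its init info from the frontend; the two insert paths above set finish_insert_init_info_ once CheckFinishInsertInitInfo finds no table still at kUnKnown. A condensed sketch of that rendezvous:

#include <condition_variable>
#include <map>
#include <mutex>
#include <string>

enum ParamType { kUnKnown, kWeight, kAccumulation };

class InitInfoRendezvous {
 public:
  // Frontend threads report each table's init info exactly once.
  void Insert(const std::string &name, ParamType type) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (types_[name] != kUnKnown) {
      return;  // already recorded
    }
    types_[name] = type;
    if (AllKnown()) {
      cv_.notify_one();
    }
  }
  // Init thread parks here until every table's type is known.
  void WaitAllKnown() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return AllKnown(); });
  }

 private:
  bool AllKnown() const {
    for (const auto &kv : types_) {
      if (kv.second == kUnKnown) {
        return false;
      }
    }
    return true;
  }
  std::mutex mutex_;
  std::condition_variable cv_;
  std::map<std::string, ParamType> types_;  // pre-populated with kUnKnown entries
};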
void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
@ -113,35 +138,49 @@ void PsCacheManager::Initialize() {
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
AddEmbeddingTable();
AllocMemForHashTable();
SetLocalIdRank();
initialized_ps_cache_ = true;
}
void PsCacheManager::AddEmbeddingTable() const {
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
size_t row_count = item.second.vocab_size;
// if worker role
worker.AddEmbeddingTable(key, row_count);
}
}
void PsCacheManager::InitParameterServer() {
std::unique_lock<std::mutex> locker(data_mutex_);
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
std::vector<size_t> keys{key, key, key, key, key, key};
std::vector<float> values{
SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1};
std::vector<int64_t> lens{2, 2, 3};
const auto &hash_table_info = item.second;
const auto &param_init_info = hash_table_info.param_init_info_;
if (param_init_info.param_type_ == kWeight) {
lens.push_back(0);
values.push_back(SizeToFloat(param_init_info.global_seed_));
values.push_back(SizeToFloat(param_init_info.op_seed_));
} else if (param_init_info.param_type_ == kAccumulation) {
lens.push_back(1);
values.push_back(param_init_info.init_val_);
}
worker.InitPSEmbeddingTable(keys, values, lens);
}
finish_init_parameter_server_ = true;
data_prase_.notify_one();
}
void PsCacheManager::AllocMemForHashTable() {
@ -208,10 +247,538 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
}
if (graph_step_ == 0) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
}
graph_step_++;
set_channel_name(channel_name);
PsDataPrefetch::GetInstance().TryWakeChannel(channel_name);
data_prase_.notify_one();
}
void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
if (!initialized_ps_cache_) {
MS_LOG(EXCEPTION) << "PS cache does not init.";
}
auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
process_data_thread.detach();
}
void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
embedding_device_cache_->cache_->InitDevice(device_id, context);
InitParameterServer();
while (true) {
ProcessData();
}
}
void PsCacheManager::ProcessData() {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr);
auto channel = channel_name();
if (channel.empty()) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return !channel_name_.empty(); });
}
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
std::unique_lock<std::mutex> locker(data_mutex_);
(void)data_prase_.wait_for(locker, std::chrono::milliseconds(100));
return;
}
IncreaseStep();
auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_);
auto batch_ids = reinterpret_cast<int *>(data);
auto batch_ids_len = data_size / sizeof(int);
std::unique_ptr<int[]> hash_index(new int[batch_ids_len]);
if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
MS_LOG(EXCEPTION) << "Process data memset failed.";
}
// Get hash swap in/out index and ids.
ParseData(batch_ids, batch_ids_len, hash_index.get());
for (const auto &item : hash_tables_) {
auto key = worker.GetParamKey(item.first);
auto hash_info = item.second;
HashSwapHostToServer(key, hash_info);
HashSwapDeviceToHost(hash_info);
HashSwapServerToHost(key, hash_info);
HashSwapHostToDevice(hash_info);
}
// Replace batch_ids with hash indices in place, so the GetNext op receives hash indices as input.
if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) {
MS_LOG(EXCEPTION) << "Process data memcpy failed.";
}
embedding_device_cache_->cache_->SynchronizeStream();
// Finish the data process and notify data prefetch.
PsDataPrefetch::GetInstance().FinalizeData(channel_name_);
(void)gettimeofday(&end_time, nullptr);
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_
<< ",graph step:" << graph_running_step_ << " channel name:" << channel_name_
<< ", time cost:" << cost / 1000 << "ms).";
}
void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_EXCEPTION_IF_NULL(batch_ids);
MS_EXCEPTION_IF_NULL(hash_index);
for (size_t i = 0; i < batch_ids_len; i++) {
bool need_swap_host_to_device = true;
bool need_swap_device_to_host = true;
auto id = batch_ids[i];
if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
hash_index[i] = -1;
continue;
}
hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
if (need_swap_host_to_device) {
ParseHostDataHostToDevice(id);
}
if (need_swap_device_to_host) {
ParseHostDataDeviceToHost(id);
}
}
// Print the ps cache hit rate every 1000 steps.
if (data_step_ % 1000 == 0) {
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%.";
}
}
void PsCacheManager::WaitGraphRun() {
MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
std::unique_lock<std::mutex> locker(data_mutex_);
if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) {
MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
<< ", graph running step:" << graph_running_step_ << ").";
}
set_current_graph_step();
}
int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) {
MS_EXCEPTION_IF_NULL(need_swap_device_to_host);
MS_EXCEPTION_IF_NULL(need_swap_host_to_device);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
MS_EXCEPTION_IF_NULL(device_to_host_index);
MS_EXCEPTION_IF_NULL(device_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_device_index);
MS_EXCEPTION_IF_NULL(host_to_device_ids);
auto device_hash_map = embedding_device_cache_->device_hash_map_;
MS_EXCEPTION_IF_NULL(device_hash_map);
int index = 0;
auto iter = device_hash_map->id_iter(id);
if (device_hash_map->IsIdExist(iter)) {
*need_swap_device_to_host = false;
*need_swap_host_to_device = false;
index = iter->second;
if (device_hash_map->hash_step(index) != data_step_) {
statistics_info_.hash_hit_count_++;
device_hash_map->set_hash_step(index, data_step_);
}
} else {
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
while (true) {
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
&(statistics_info_.device_to_host_size_));
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
continue;
}
host_to_device_index[statistics_info_.host_to_device_size_] = index;
host_to_device_ids[statistics_info_.host_to_device_size_] = id;
statistics_info_.host_to_device_size_++;
*need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size;
break;
}
}
return index;
}
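ParseDeviceData relies on an embedding hash map that behaves like an LRU keyed by step: a slot can be recycled only if its entry was last touched at a step the graph has already finished; otherwise ParseData returns INVALID_INDEX_VALUE and the caller falls back to WaitGraphRun. A minimal sketch of that slot-allocation rule (assumed semantics; the real EmbeddingHashMap also records the evicted ids for swap-out):

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

constexpr int kInvalidIndex = -1;

class StepHashMapSketch {
 public:
  explicit StepHashMapSketch(size_t capacity) : ids_(capacity, -1), steps_(capacity, 0) {}

  // Return a slot for id, or kInvalidIndex if nothing is evictable yet.
  int Acquire(int id, uint64_t data_step, uint64_t graph_running_step) {
    auto it = index_.find(id);
    if (it != index_.end()) {  // cache hit: refresh the step stamp
      steps_[it->second] = data_step;
      return static_cast<int>(it->second);
    }
    for (size_t slot = 0; slot < ids_.size(); ++slot) {
      // Free, or last used in a step the graph has already consumed.
      if (ids_[slot] == -1 || steps_[slot] <= graph_running_step) {
        if (ids_[slot] != -1) {
          index_.erase(ids_[slot]);  // evict the old occupant
        }
        ids_[slot] = id;
        steps_[slot] = data_step;
        index_[id] = slot;
        return static_cast<int>(slot);
      }
    }
    return kInvalidIndex;  // caller must WaitGraphRun() and retry
  }

 private:
  std::vector<int> ids_;         // id occupying each slot (-1 = free)
  std::vector<uint64_t> steps_;  // step at which each slot was last touched
  std::unordered_map<int, size_t> index_;
};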
void PsCacheManager::ParseHostDataHostToDevice(size_t id) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
MS_EXCEPTION_IF_NULL(host_to_server_index);
MS_EXCEPTION_IF_NULL(host_to_server_ids);
MS_EXCEPTION_IF_NULL(server_to_host_index);
MS_EXCEPTION_IF_NULL(server_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_device_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
MS_EXCEPTION_IF_NULL(host_hash_map);
auto iter = host_hash_map->id_iter(id);
if (host_hash_map->IsIdExist(iter)) {
auto index = iter->second;
if (host_hash_map->hash_step(index) != data_step_) {
host_hash_map->set_hash_step(index, data_step_);
}
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
} else {
while (true) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
continue;
}
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
server_to_host_index[statistics_info_.server_to_host_size_] = index;
server_to_host_ids[statistics_info_.server_to_host_size_++] = id;
break;
}
}
}
void PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
MS_EXCEPTION_IF_NULL(device_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_server_index);
MS_EXCEPTION_IF_NULL(host_to_server_ids);
MS_EXCEPTION_IF_NULL(device_to_host_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
MS_EXCEPTION_IF_NULL(host_hash_map);
int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
auto iter = host_hash_map->id_iter(swap_device_to_host_id);
if (host_hash_map->IsIdExist(iter)) {
auto index = iter->second;
if (host_hash_map->hash_step(index) != data_step_) {
host_hash_map->set_hash_step(index, data_step_);
}
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
} else {
while (true) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
continue;
}
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
break;
}
}
}
void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
const float *input_addr, const int *indices_addr, float *output_addr) {
auto type_size = sizeof(float);
size_t lens = outer_dim_size * type_size;
for (size_t i = 0; i < indices_lens; ++i) {
int index = indices_addr[i];
if (index >= 0 && index < SizeToInt(first_dim_size)) {
size_t pos = index * outer_dim_size;
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
}
} else {
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
}
}
output_addr += outer_dim_size;
}
}
void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t outer_dim_size = embedding_size;
size_t thread_num = indices_lens / 10000 + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
for (i = 0; i < thread_num; i++) {
if (task_offset >= indices_lens) {
break;
}
MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size,
hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size);
task_offset += task_proc_lens;
if (task_offset + task_proc_lens > indices_lens) {
task_proc_lens = indices_lens - task_offset;
}
}
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
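The threading in LookUpHostHashTable (and in InsertHostHashTable below) is a plain parallel-for: ceil-divide the work across at most kMaxThreadNum workers and shorten the final chunk. The same pattern in isolation, with an assumed thread cap:

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

constexpr size_t kMaxThreadNum = 16;  // assumed cap, mirroring the member above

// Run fn(offset, len) over [0, total) in roughly equal chunks.
template <typename Fn>
void ParallelFor(size_t total, Fn fn) {
  size_t thread_num = std::min(total / 10000 + 1, kMaxThreadNum);
  size_t chunk = (total + thread_num - 1) / thread_num;  // ceil division
  std::vector<std::thread> threads;
  for (size_t offset = 0; offset < total; offset += chunk) {
    size_t len = std::min(chunk, total - offset);  // last chunk may be short
    threads.emplace_back(fn, offset, len);
  }
  for (auto &t : threads) {
    t.join();
  }
}

int main() {
  std::atomic<size_t> processed{0};
  ParallelFor(100000, [&](size_t, size_t len) { processed += len; });
  return processed == 100000 ? 0 : 1;
}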
}
void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
float *insert_data, float *hash_table_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t thread_num = insert_indices_size / 10000 + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;
auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
int *insert_indices, float *insert_data, float *hash_table_addr) {
auto type_size = sizeof(float);
size_t lens = outer_dim_size * type_size;
for (size_t i = 0; i < insert_indices_size; ++i) {
int index = insert_indices[i];
if (index >= 0 && index < SizeToInt(first_dim_size)) {
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed.";
}
}
}
};
for (i = 0; i < thread_num; i++) {
if (task_offset >= insert_indices_size) {
break;
}
MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size,
insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr);
task_offset += task_proc_lens;
if (task_offset + task_proc_lens > insert_indices_size) {
task_proc_lens = insert_indices_size - task_offset;
}
}
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
}
void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get();
auto swap_indices_size = statistics_info_.host_to_device_size_;
if (swap_indices_size == 0) {
return;
}
auto embedding_size = hash_info.embedding_size;
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index,
swap_out_data.get());
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
swap_out_data.get(),
swap_indices_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_indices_size);
}
void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
auto swap_indices_size = statistics_info_.device_to_host_size_;
auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get();
if (swap_indices_size == 0) {
return;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_indices_size);
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(),
embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->SynchronizeStream();
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
swap_out_data.get(), host_hash_table_addr);
}
void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
auto swap_indices_size = statistics_info_.host_to_server_size_;
if (swap_indices_size == 0) {
return;
}
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
::ps::SArray<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
swap_out_data.resize(swap_indices_size * embedding_size);
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
swap_out_data.data());
auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
}
void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
auto swap_indices_size = statistics_info_.server_to_host_size_;
auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
if (swap_indices_size == 0) {
return;
}
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
::ps::SArray<int> lengths{swap_indices_size};
::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0);
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(),
host_hash_table_addr);
}
void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(swap_out_index);
MS_EXCEPTION_IF_NULL(swap_out_data);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
auto swap_out_index_size = statistics_info_.device_to_host_size_;
if (swap_out_index_size == 0) {
return;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto embedding_size = hash_info.embedding_size;
swap_out_data->resize(swap_out_index_size * embedding_size);
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index,
swap_out_index_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_out_index_size);
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(),
embedding_device_cache_->hash_swap_value_addr_,
swap_out_index_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->RecordEvent();
}
void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info,
size_t key) {
MS_EXCEPTION_IF_NULL(swap_in_ids);
MS_EXCEPTION_IF_NULL(swap_in_index);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
auto swap_in_ids_size = statistics_info_.host_to_device_size_;
if (swap_in_ids_size == 0) {
return;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto embedding_size = hash_info.embedding_size;
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
::ps::SArray<int> lengths{swap_in_ids_size};
::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0);
::ps::SArray<int> lookup_ids(swap_in_ids_size, 0);
auto copy_len = swap_in_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
// Hash swap-in in device.
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
lookup_result.data(),
swap_in_ids_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index,
swap_in_ids_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_in_ids_size);
}
void PsCacheManager::UpdateEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(swap_out_ids);
auto swap_out_ids_size = statistics_info_.device_to_host_size_;
if (swap_out_ids_size == 0) {
return;
}
::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
auto copy_len = swap_out_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
// Need synchronize event to ensure that the swap-out in device is completed.
embedding_device_cache_->cache_->SynchronizeEvent();
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
}
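Note the pipelining contract between HashSwapDeviceOut and the table update above: the device-to-host copy of evicted rows is issued asynchronously and sealed with RecordEvent, the host proceeds with other work, and only the final push to the server calls SynchronizeEvent. A host-side analogue with std::async standing in for the device event API:

#include <future>
#include <vector>

int main() {
  std::vector<float> swap_out_data(1024);
  // "RecordEvent": the copy runs asynchronously, like a device stream.
  std::future<void> event = std::async(std::launch::async, [&swap_out_data] {
    for (auto &v : swap_out_data) {
      v = 0.5f;  // stands in for CopyDeviceMemToHost of evicted rows
    }
  });
  // ... overlap: parse the next batch, pack lookup ids, etc. ...
  event.wait();  // "SynchronizeEvent": the copy must finish before the push
  // Now safe to send swap_out_data to the parameter server.
  return 0;
}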
void PsCacheManager::DumpHashTables() const {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t cache_vocab_size = item.second.cache_vocab_size;
size_t embedding_size = item.second.embedding_size;
size_t vocab_size = item.second.vocab_size;
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr)
<< " || " << reinterpret_cast<void *>(item.second.host_address.get());
float *output = new float[item.second.device_address.size / sizeof(float)];
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
item.second.device_address.size);
embedding_device_cache_->cache_->SynchronizeStream();
for (size_t i = 0; i < cache_vocab_size; i++) {
for (size_t j = 0; j < embedding_size; j++) {
std::cout << output[i * embedding_size + j] << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
delete[] output;
}
}
} // namespace ps
} // namespace mindspore

View File

@ -49,6 +49,7 @@ struct HashTableInfo {
size_t vocab_size{0};
Address device_address{nullptr, 0};
std::shared_ptr<int[]> host_address{nullptr};
ParamInitInfo param_init_info_;
};
struct EmbeddingDeviceCache {
@ -158,6 +159,8 @@ class PsCacheManager {
void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;
void AddEmbeddingTable() const;
bool initialized_ps_cache_{false};
std::string channel_name_;
@ -167,6 +170,7 @@ class PsCacheManager {
size_t data_step_{0};
std::mutex data_mutex_;
std::condition_variable data_prase_;
std::condition_variable insert_init_info_;
std::map<std::string, HashTableInfo> hash_tables_;
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
@ -178,6 +182,8 @@ class PsCacheManager {
size_t batch_elements_{0};
PsCacheStatisticsInfo statistics_info_;
std::pair<size_t, size_t> range_bound_;
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
};
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace ps {
class PsDataPrefetch {
class EXPORT PsDataPrefetch {
public:
EXPORT static PsDataPrefetch &GetInstance() {
static PsDataPrefetch instance;

View File

@ -17,6 +17,11 @@
#include "ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore {
namespace ps {
@ -80,5 +85,43 @@ bool PSContext::is_role_sched() const { return is_sched_; }
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
int PSContext::ps_rank_id() const { return rank_id_; }
void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
#endif
}
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
#endif
}
void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
#endif
}
void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
#endif
}
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
#endif
}
void PSContext::set_cache_enable(bool cache_enable) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
#endif
}
} // namespace ps
} // namespace mindspore

View File

@ -44,6 +44,14 @@ class PSContext {
bool is_role_sched() const;
void SetPSRankId(int rank_id);
int ps_rank_id() const;
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const;
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size) const;
void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const;
void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
void set_cache_enable(bool cache_enable) const;
private:
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}

View File

@ -0,0 +1,71 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/random_normal/random_normal.h"
#include <iostream>
#include <thread>
#include <memory>
#include "utils/convert_utils_base.h"
#include "pybind_api/random_normal/random_cpu_kernel.h"
namespace mindspore {
namespace ps {
bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed,
float *output_data) {
if (out_shape.size() == 0) {
std::cout << "output data shape is invalid" << std::endl;
return false;
}
int64_t total_count = 1;
for (uint32_t i = 0; i < out_shape.size(); i++) {
total_count *= SizeToLong(out_shape[i]);
}
uint32_t thread_num = 16;
if (total_count <= thread_num) {
thread_num = 1;
}
float *start_ptr = output_data;
if (start_ptr == nullptr) {
std::cout << "start_ptr is nullptr" << std::endl;
return false;
}
int64_t batch_size = total_count / thread_num;
std::vector<std::thread> threads(thread_num);
int64_t seed = SizeToLong(global_seed);
int64_t seed2 = SizeToLong(op_seed);
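// If both seeds are zero, fall back to a time-based seed so repeated calls do
// not generate identical values.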
seed = (seed == 0 && seed2 == 0) ? clock() : seed;
PhiloxGenerator generator = PhiloxGenerator(seed, seed2);
if (thread_num != 1) {
for (uint32_t i = 0; i < thread_num - 1; i++) {
float *offset_ptr = start_ptr + batch_size * i;
threads[i] =
std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, offset_ptr, batch_size, i);
}
float *offset_ptr = start_ptr + batch_size * (thread_num - 1);
threads[thread_num - 1] = std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator,
offset_ptr, total_count - (thread_num - 1) * batch_size, thread_num - 1);
} else {
threads[0] =
std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, start_ptr, total_count, 0);
}
for (uint32_t i = 0; i < thread_num; i++) {
threads[i].join();
}
// Transform the standard-normal samples to N(mean, stddev^2); mean was otherwise unused.
for (int64_t i = 0; i < total_count; i++) {
output_data[i] = output_data[i] * stddev + mean;
}
return true;
}
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_
#define MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_
#include <vector>
namespace mindspore {
namespace ps {
bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed,
float *output_data);
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_

View File

@ -26,6 +26,15 @@
namespace mindspore {
namespace ps {
enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 };
struct ParamInitInfo {
ParamType param_type_{kUnKnown};
size_t global_seed_{0};
size_t op_seed_{0};
float init_val_{0};
};
class Util {
public:
static bool IsParamServerMode();

View File

@ -32,6 +32,7 @@
#include "ps/common.h"
#include "ps/worker_proxy.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
namespace mindspore {
namespace ps {
@ -47,15 +48,19 @@ class Worker {
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
void Pull(const size_t key, void *dev_addr, const size_t size);
size_t SetParamKey(const std::string &param_name);
size_t GetParamKey(const std::string &param_name);
void SetParamInitInServer(const std::string &param_name, bool init_in_server);
bool GetParamInitInServer(const std::string &param_name);
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes);
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd);
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals);
bool running() { return running_; }
void Finalize();
private:
@ -65,7 +70,6 @@ class Worker {
Worker &operator=(const Worker &) = delete;
bool IsKeyInit(const size_t key);
size_t GetParamKey(const std::string &param_name);
void InitPSOptimId(const size_t param_key);
void InitPSOptimInputShapes(const size_t key);
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
@ -187,6 +191,12 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
}
template <typename T>
void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals) {
kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals);
}
template <typename T>
void Worker<T>::Finalize() {
if (running_) {
@ -286,7 +296,7 @@ size_t Worker<T>::GetParamKey(const std::string &param_name) {
size_t key = kInvalidKey;
if (param_to_key_.find(param_name) != param_to_key_.end()) {
key = param_to_key_[param_name];
MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key;
MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
}
return key;
}
@ -310,8 +320,7 @@ void Worker<T>::InitPSOptimId(const size_t param_key) {
}
template <typename T>
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes,
const ShapeVector &sizes) {
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) {
bool has_init = IsKeyInit(keys[0]);
if (has_init) {
MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized.";
@ -319,7 +328,7 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
}
::ps::SArray<T> shapes_val;
for (auto dim : shapes) {
shapes_val.push_back(static_cast<T>(dim));
shapes_val.push_back(dim);
}
std::vector<int> sizes_int;
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
@ -337,9 +346,6 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::
const std::string &param_name = pk_node->fullname_with_scope();
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
}
size_t param_key = GetParamKey(param_name);
if (param_key == kInvalidKey) {
@ -357,11 +363,17 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
<< ", whether init in server: " << init_in_server;
kv_worker_->AddKeyToServerId(param_key);
if (!init_in_server) {
InitPSParamData({param_key}, param_data, param_size);
if (!PsDataPrefetch::GetInstance().cache_enable()) {
if (!init_in_server) {
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
<< param_size;
}
InitPSParamData({param_key}, param_data, param_size);
}
InitPSOptimId(param_key);
InitPSOptimInputShapes(param_key);
}
InitPSOptimId(param_key);
InitPSOptimInputShapes(param_key);
}
}

View File

@ -45,6 +45,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
: Worker(app_id, customer_id) {
server_num_ = ::ps::NumServers();
MS_LOG(INFO) << "Server num:" << server_num_;
PSContext::instance()->SetPSRankId(::ps::MyRank());
using std::placeholders::_1;
using std::placeholders::_2;
@ -60,6 +61,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
}
~WorkerProxy() override = default;
@ -70,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker<T> {
const Callback &cb = nullptr, int64_t priority = 0);
int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key);
void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
@ -98,6 +102,9 @@ class WorkerProxy : public ::ps::KVWorker<T> {
void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs);
void UpdateEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs);
void ProcessLookupResult(const ::ps::Message &msg);
void ProcessResponse(const ::ps::Message &msg);
void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
@ -122,6 +129,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
Slicer broadcast_slicer_;
Slicer round_robin_slicer_;
Slicer worker_init_embedding_slicer_;
Slicer update_embedding_slicer_;
std::unordered_map<int64_t, Callback> lookup_callbacks_;
std::unordered_map<int64_t, Callback> general_callbacks_;
std::unordered_map<int64_t, int64_t> expected_result_count_;
@ -195,6 +203,24 @@ int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys,
return ts;
}
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
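// Pack the update request: keys carries the embedding-table key, the lens field
// is reused to carry the lookup ids, and vals holds the embedding rows.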
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.lens = lookup_ids;
kvs.vals = vals;
kvs.priority = priority;
expected_result_count_[ts] = 0;
Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
if (expected_result_count_[ts] < server_num_) {
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
}
general_customer_->WaitRequest(ts);
expected_result_count_.erase(ts);
}
template <typename T>
bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
::ps::SArray<T> result(1, 0);
@ -724,6 +750,47 @@ void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KV
}
}
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(sliced);
T *embedding_vals = send.vals.data();
int *lookup_ids = send.lens.data();
size_t val_size = send.vals.size();
size_t id_size = send.lens.size();
size_t embedding_dim = val_size / id_size;
const Key &key = send.keys[0];
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
sliced->resize(ranges.size());
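// For each server's key range, keep only the lookup ids that fall inside it,
// together with their embedding_dim values; a slice holding just the table key
// has nothing to send and is marked invalid.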
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
auto &kvs = sliced->at(i).second;
kvs.keys.push_back(key);
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
if (lookup_id >= begin && lookup_id <= end) {
kvs.keys.push_back(lookup_id);
for (size_t k = 0; k < embedding_dim; k++) {
kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
}
}
}
if (kvs.keys.size() <= 1) {
sliced->at(i).first = false;
} else {
sliced->at(i).first = true;
expected_result_count_[timestamp] += 1;
}
}
}
template <typename T>
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
int64_t ts = msg.meta.timestamp;

View File

@ -18,6 +18,7 @@
#include <thread>
#include <memory>
#include "runtime/device/cpu/cpu_device_address.h"
#include "ir/tensor.h"
namespace mindspore {
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,

View File

@ -19,8 +19,7 @@
#include "pybind_api/random_normal/philox_generator.h"
#include "pybind11/pybind11.h"
#include "pybind_api/api_register.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "utils/log_adapter.h"
namespace py = pybind11;

View File

@ -55,6 +55,7 @@ class AscendKernelRuntime : public KernelRuntime {
bool SyncStream() override;
void SetContext() override;
void CreateContext() override;
void *context() const override { return rt_context_; }
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

View File

@ -18,6 +18,7 @@
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "utils/ms_context.h"
#include "utils/convert_utils.h"
#include "ps/ps_cache/ps_cache_manager.h"
namespace mindspore {
namespace device {
namespace gpu {
@ -38,6 +39,9 @@ void GPUMemoryManager::MallocDeviceMemory() {
MS_EXCEPTION_IF_NULL(context_ptr);
// If the dynamic memory pool is used, allocate the first memory block to initialize it.
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
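// The PS cache initialization allocates from the memory pool itself, so the
// probe allocation below is skipped in that case.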
if (ps::ps_cache_instance.initialized_ps_cache()) {
return;
}
auto device_addr = MallocMemFromMemPool(1);
if (!device_addr) {
MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";

View File

@ -30,6 +30,10 @@
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
@ -331,15 +335,27 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
continue;
}
DeviceAddressPtr device_address = nullptr;
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
const std::string &param_name = item->fullname_with_scope();
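// Embedding-cache tables already own device memory managed by the PS cache
// manager, so wrap that address instead of allocating static memory.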
if (ps::ps_cache_instance.IsHashTable(param_name)) {
const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
MS_EXCEPTION_IF_NULL(address.addr);
device_address =
CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
AnfAlgo::SetOutputAddr(device_address, index, item.get());
continue;
}
#endif
auto tensor_size = CountNodeDeviceMemorySize(item, index);
auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
}
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
<< " index: " << index << " size: " << tensor_size;
AnfAlgo::SetOutputAddr(address, index, item.get());
AnfAlgo::SetOutputAddr(device_address, index, item.get());
}
}
MS_LOG(INFO) << "AssignStaticMemoryInput end";

View File

@ -78,6 +78,7 @@ class KernelRuntime {
virtual void ClearGlobalIdleMem() {}
virtual void CreateContext() {}
virtual void SetContext() {}
virtual void *context() const { return nullptr; }
uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
return mem_manager_->MallocMem(type, size, address);
}

View File

@ -15,6 +15,7 @@
"""Parameter for cell."""
from copy import copy
import numbers
from .._c_expression import ParamInfo
from .._c_expression import MetaTensor as MetaTensor_
from . import dtype as mstype
@ -23,7 +24,10 @@ from .tensor import Tensor, MetaTensor
from .._checkparam import Validator
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
from ..parallel._ps_context import _reinsert_hash_table_size
from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
from .seed import _get_global_and_op_seed
__all__ = ['Parameter', 'ParameterTuple']
@ -35,6 +39,18 @@ def _is_in_parallel_mode():
"""Get parallel mode."""
return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
def init_to_value(init):
"""Get value of initializer."""
if isinstance(init, str):
if init == 'zeros':
return 0.0
if init == 'ones':
return 1.0
raise ValueError("init should be one of values in 'zeros', 'ones'.")
if isinstance(init, numbers.Number):
return float(init)
raise ValueError("init should be number or string")
class Parameter(MetaTensor_):
"""
@ -118,6 +134,8 @@ class Parameter(MetaTensor_):
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
self._param_info = ParamInfo()
self.init_in_server = False
self.cache_enable = False
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
@ -129,7 +147,6 @@ class Parameter(MetaTensor_):
self._sliced = False
self.is_param_ps = False
self._cast_type = None
self.init_in_server = False
self._unique = False
self.is_in_parallel = _is_in_parallel_mode()
if isinstance(default_input, (MetaTensor, Tensor)):
@ -155,7 +172,7 @@ class Parameter(MetaTensor_):
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, MetaTensor):
if _is_in_parallel_mode() or _is_role_worker():
if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched():
# do not init data while in auto parallel.
return (MetaTensor_, data.dtype, data.shape)
data = data.to_tensor()
@ -189,18 +206,18 @@ class Parameter(MetaTensor_):
init_in_server (bool): Whether the trainable parameter updated by the parameter server is
initialized on the server. Default: False.
"""
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
if init_in_server and (not self.name.endswith("embedding_table")):
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
"sparse operator support initialization in server.".format(self.name))
self.is_param_ps = True
self.init_in_server = init_in_server
self._param_info.init_in_server = init_in_server
else:
if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()):
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
1. set_ps_context(enable_ps=True) \
2. export MS_ROLE environment variable.")
if init_in_server and (not self.name.endswith("embedding_table")):
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
"sparse operator support initialization in server.".format(self.name))
self.is_param_ps = True
self.init_in_server = init_in_server
self._param_info.init_in_server = init_in_server
@property
def inited_param(self):
@ -238,6 +255,13 @@ class Parameter(MetaTensor_):
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
else:
raise ValueError("The type of the name should be `str` or `None`.")
if _is_role_worker() and self.cache_enable:
if len(self.shape) != 2:
raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
.format(self.name, len(self.shape)))
_reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1])
self._param_info.name = name_
@property
@ -297,6 +321,7 @@ class Parameter(MetaTensor_):
x.is_init = False
x.is_param_ps = self.is_param_ps
x.init_in_server = self.init_in_server
x.cache_enable = self.cache_enable
if init != 'same':
shape = self.shape
dtype = self.dtype
@ -431,15 +456,18 @@ class Parameter(MetaTensor_):
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1]))
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
if _is_role_worker():
if _is_role_worker() or _is_role_sched():
data = self.init_mode.to_tensor(0, [1])
else:
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
else:
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
else:
if _is_role_worker() and self.cache_enable:
global_seed, op_seed = _get_global_and_op_seed()
_insert_weight_init_info(self.name, global_seed, op_seed)
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
if _is_role_worker():
if _is_role_worker() or _is_role_sched():
data = self.init_mode.to_tensor(0, [1])
else:
data = self.init_mode.to_tensor()
@ -502,6 +530,16 @@ class ParameterTuple(tuple):
x1 = x.clone(init)
x1.name = prefix + "." + x1.name
new.append(x1)
if not x1.cache_enable:
continue
if not x1.name.endswith("embedding_table"):
raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of "
"sparse operator support enable cache.".format(x1.name))
if _is_role_worker():
_clone_hash_table(x.name, x1.name)
_insert_accumu_init_info(x1.name, init_to_value(init))
return ParameterTuple(new)
def __parameter_tuple__(self):

View File

@ -195,6 +195,20 @@ def _get_op_seed(op_seed, kernel_name):
return _KERNEL_SEED[(kernel_name, op_seed)]
def _get_global_and_op_seed():
"""Get global_seed and op_seed."""
global_seed = get_seed()
op_seed = get_seed()
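# Note: the op seed also starts from the global seed here; _get_op_seed("init")
# below derives the kernel-level seed from it.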
if global_seed == 0:
global_seed = DEFAULT_GRAPH_SEED
elif global_seed is None:
global_seed = 0
Validator.check_non_negative_int(op_seed, "seed", "init")
temp_seed = _get_op_seed(op_seed, "init")
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
return seeds
def _get_graph_seed(op_seed, kernel_name):
"""
Get the graph-level seed.

View File

@ -73,6 +73,8 @@ inline size_t FloatToSize(float u) {
}
inline float IntToFloat(int32_t v) { return static_cast<float>(v); }
inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
inline double FloatToDouble(float v) { return static_cast<double>(v); }

View File

@ -22,6 +22,7 @@ from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@ -156,6 +157,7 @@ class EmbeddingLookup(Cell):
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0 (cache disabled).
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@ -185,7 +187,7 @@ class EmbeddingLookup(Cell):
def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True):
max_norm=None, sparse=True, vocab_cache_size=0):
super(EmbeddingLookup, self).__init__()
self.target = target
if target not in ('CPU', 'DEVICE'):
@ -199,11 +201,23 @@ class EmbeddingLookup(Cell):
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.cache_enable = self.vocab_cache_size > 0
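# With the cache enabled, the device-side embedding table only holds the cached
# rows, so the cache size (scaled by group size under auto parallel) becomes the
# table's first dimension.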
if self.cache_enable:
if is_auto_parallel:
self.vocab_cache_size = self.vocab_cache_size * get_group_size()
self.vocab_size = self.vocab_cache_size
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
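# Register the cache with the PS context on the worker: the device keeps
# vocab_cache_size rows while the full vocab_size table lives on the server.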
if self.cache_enable:
self.embedding_table.cache_enable = True
_set_cache_enable(True)
if _is_role_worker():
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
self.forward_unique = False
self.gather_revert = P.GatherV2()
self.unique = P.Unique().shard(((1,),))
@ -222,7 +236,7 @@ class EmbeddingLookup(Cell):
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
if target == 'DEVICE':
if target == 'DEVICE' and not self.cache_enable:
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (get_group_size(),)))
self.forward_unique = True

View File

@ -88,14 +88,14 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter):
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
shapes = (op_shape(param), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
@ -158,12 +158,13 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter):
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
moment1, moment2, ps_parameter, cache_enable):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
@ -338,12 +339,12 @@ class Adam(Optimizer):
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2, self.ps_parameters)
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
else:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2, self.ps_parameters)
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
return success
@Optimizer.target.setter

View File

@ -24,14 +24,14 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
"RowTensor", "Tensor", "Tensor", "Bool")
"RowTensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
gradient, weight, moment, ps_parameter):
gradient, weight, moment, ps_parameter, cache_enable):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((values, indices), shapes), weight))
@ -41,12 +41,12 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, le
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Bool")
"Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
gradient, weight, moment, ps_parameter):
gradient, weight, moment, ps_parameter, cache_enable):
"""Apply ftrl optimizer to the weight parameter."""
success = True
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
(op_shape(weight), op_shape(moment), op_shape(linear))), weight))
@ -185,7 +185,7 @@ class FTRL(Optimizer):
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.l1, self.l2, self.lr_power, lr),
linear, grads, params, moments, self.ps_parameters)
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
return success
@Optimizer.target.setter

View File

@ -156,6 +156,8 @@ class Optimizer(Cell):
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
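# Like ps_parameters, record a per-parameter flag marking embedding tables
# served from the device-side cache.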
ps_cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.need_scale = loss_scale != 1.0
self.global_step_increase_tensor = Tensor(1, mstype.int32)

View File

@ -117,3 +117,21 @@ def _is_role_pserver():
def _is_role_sched():
return ps_context().is_role_sched()
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
def _insert_weight_init_info(name, global_seed, op_seed):
ps_context().insert_weight_init_info(name, global_seed, op_seed)
def _insert_accumu_init_info(name, init_val):
ps_context().insert_accumu_init_info(name, init_val)
def _clone_hash_table(dest_param_name, src_param_name):
ps_context().clone_hash_table(dest_param_name, src_param_name)
def _set_cache_enable(cache_enable):
ps_context().set_cache_enable(cache_enable)

View File

@ -92,6 +92,10 @@ def connect_network_with_dataset(network, dataset_helper):
if isinstance(dataset_iter, _DatasetIterNormal):
raise RuntimeError("Dataset should be connected with network only in sink mode.")
ms_role = os.getenv("MS_ROLE")
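# Parameter servers and the scheduler never run the dataset-sink graph, so the
# network is returned unwrapped for those roles.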
if ms_role in ("MS_PSERVER", "MS_SCHED"):
return network
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
and context.get_context("device_target") == "Ascend" \
@ -166,14 +170,14 @@ class DatasetHelper:
iterclass = _DatasetIterGE
else:
if context.get_context("mode") == context.GRAPH_MODE:
if context.get_context("device_target") == "Ascend":
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
iterclass = _DatasetIterPSServer
elif ms_role == "MS_WORKER":
iterclass = _DatasetIterPSWork
elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "GPU"):
iterclass = _DatasetIterMSLoopSink
elif context.get_context("device_target") == "GPU":
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
iterclass = _DatasetIterPSLite
else:
iterclass = _DatasetIterMSLoopSink
elif context.get_context("device_target") == "CPU":
raise RuntimeError(
"Currently dataset sink mode is not supported when the device target is CPU.")
@ -218,7 +222,10 @@ class _DatasetIter:
if not hasattr(dataset, '__transfer_dataset__'):
if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
ms_role = os.getenv("MS_ROLE")
# PS mode does not support loop sink, so the worker needs the real sink size.
if ms_role != "MS_WORKER":
self.sink_size = dataset.__loop_size__
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
"device_target") == "Ascend")
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
@ -260,8 +267,12 @@ class _DatasetIter:
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
ms_role = os.getenv("MS_ROLE")
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
elif ms_role == "MS_WORKER":
# PS mode does not support loop sink.
sink_size = 1
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \
or context.get_context("device_target") == "GPU":
@ -311,9 +322,6 @@ class _DatasetIterMSLoopSink(_DatasetIter):
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
self.sink_count = self.get_sink_count(dataset)
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
@ -341,8 +349,8 @@ class _DatasetIterMS(_DatasetIter):
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
class _DatasetIterPSLite(_DatasetIter):
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
class _DatasetIterPSServer(_DatasetIter):
"""Iter for context on MS_PSERVER or MS_SCHED"""
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
@ -355,6 +363,20 @@ class _DatasetIterPSLite(_DatasetIter):
self.op = op
class _DatasetIterPSWork(_DatasetIter):
"""Iter for context on MS_WORKER"""
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
if sink_size > 0:
self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()
def op():
return tuple()
self.op = op
class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host."""

View File

@ -30,6 +30,7 @@ def argparse_init():
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.")
parser.add_argument("--vocab_cache_size", type=int, default=0, help="The total features of hash table.")
parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.")
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128],
help="The dimension of all deep layers.")
@ -66,6 +67,7 @@ class WideDeepConfig():
self.eval_batch_size = 16000
self.field_size = 39
self.vocab_size = 200000
self.vocab_cache_size = 100000
self.emb_dim = 80
self.deep_layer_dim = [1024, 512, 256, 128]
self.deep_layer_act = 'relu'
@ -103,6 +105,7 @@ class WideDeepConfig():
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size
self.vocab_size = args.vocab_size
self.vocab_cache_size = args.vocab_cache_size
self.emb_dim = args.emb_dim
self.deep_layer_dim = args.deep_layer_dim
self.deep_layer_act = args.deep_layer_act

View File

@ -147,6 +147,7 @@ class WideDeepModel(nn.Cell):
sparse = config.sparse
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.vocab_cache_size = config.vocab_cache_size
self.emb_dim = config.emb_dim
self.deep_layer_dims_list = config.deep_layer_dim
self.deep_layer_act = config.deep_layer_act
@ -237,8 +238,20 @@ class WideDeepModel(nn.Cell):
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif parameter_server:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
cache_enable = self.vocab_cache_size > 0
target = 'DEVICE' if cache_enable else 'CPU'
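# With the embedding cache the lookup runs on DEVICE against the cached rows;
# without it the lookup stays on CPU in parameter-server mode.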
if is_auto_parallel and config.full_batch and cache_enable:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, sparse=sparse,
vocab_cache_size=self.vocab_cache_size)
self.embedding_table = self.deep_embeddinglookup.embedding_table
self.deep_embeddinglookup.embedding_table.set_param_ps()
self.wide_embeddinglookup.embedding_table.set_param_ps()
@ -344,7 +357,7 @@ class TrainStepWrap(nn.Cell):
"""
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False,
sparse=False):
sparse=False, cache_enable=False):
super(TrainStepWrap, self).__init__()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@ -361,7 +374,7 @@ class TrainStepWrap(nn.Cell):
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
if (sparse and is_auto_parallel) or parameter_server:
if (sparse and is_auto_parallel) or (parameter_server and not cache_enable):
self.optimizer_d = LazyAdam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
@ -417,10 +430,17 @@ class TrainStepWrap(nn.Cell):
class PredictWithSigmoid(nn.Cell):
"""
Predict definition
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__()
self.network = network
self.sigmoid = P.Sigmoid()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
self.sigmoid.shard(((1, 1),))
def construct(self, batch_ids, batch_wts, labels):
logits, _, = self.network(batch_ids, batch_wts)

View File

@ -39,7 +39,8 @@ def get_WideDeep_net(config):
"""
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server))
train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server),
cache_enable=bool(config.vocab_cache_size > 0))
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net
@ -81,6 +82,7 @@ def train_and_eval(config):
else:
dataset_type = DataType.H5
parameter_server = bool(config.parameter_server)
cache_enable = bool(config.vocab_cache_size > 0)
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(),
@ -111,7 +113,7 @@ def train_and_eval(config):
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train,
callbacks=callback_list,
dataset_sink_mode=(not parameter_server))
dataset_sink_mode=(parameter_server and cache_enable))
if __name__ == "__main__":