diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ee33ff39fc2..5939449692c 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1194,6 +1194,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); + std::vector shape_init_in_server = {1}; for (size_t i = 0; i < inputs.size(); ++i) { auto tensor = inputs[i]; MS_EXCEPTION_IF_NULL(tensor); @@ -1201,8 +1202,13 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { auto pk_node = input_node->cast(); + bool init_in_server = false; + if (tensor->shape_c() == shape_init_in_server) { + MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope(); + init_in_server = true; + } mindspore::parallel::ps::Worker::GetInstance().InitPSParamAndOptim( - pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes())); + pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server); } } ps_init_ = true; diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 429cf59c3c1..48d70d027fd 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -542,6 +542,10 @@ inline bool ParameterServer::ReadyForUpdateWeights() { template inline bool ParameterServer::ReadyForAccumGrads() { + if (weights_.empty()) { + MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send " + "kInitWeightsCmd command. 2.The Server failed to initialize weights."; + } return grad_accum_count_ < weights_.size(); } diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index d7f0bb6df52..bde21904bb9 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -47,7 +47,8 @@ class Worker { void SetOptimInputShapes(size_t key, const std::vector &shape); void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); - void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); + void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size, + bool init_in_server = false); void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); void Finalize(); @@ -240,7 +241,8 @@ void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vecto template // Initialize parameters and optimizer kernels of Parameter Server. -void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) { +void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size, + bool init_in_server) { size_t param_key = GetParamKey(param_name); if (param_key == kInvalidKey) { MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; @@ -248,9 +250,9 @@ void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_d } bool init = IsKeyInit(param_key); if (!init) { - MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; - // No need to push embedding table data to Parameter Server. - if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { + MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name + << ", whether init in server: " << init_in_server; + if (!init_in_server) { InitPSParamData({param_key}, param_data, param_size); } InitPSOptimId(param_key);