diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 60e88a9c10f..107a9e4752d 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -28,7 +28,9 @@ #include #include #include +#include #include +#include #include "ir/func_graph.h" #include "backend/session/session_basic.h" #include "backend/session/anf_runtime_algorithm.h" @@ -52,6 +54,7 @@ namespace mindspore { namespace parallel { namespace ps { using mindspore::kernel::ps::PServerKernel; +using AnfAlgo = session::AnfRuntimeAlgorithm; template class ParameterServer { public: @@ -126,6 +129,8 @@ class ParameterServer { void ResetGradAccumCount(); const CNodePtr GetCNode(const std::string &name) const; std::mutex &mutex(); + void GetEmbeddingTableParamPtr(); + void SyncEmbeddingTables(); size_t pserver_num_; size_t worker_num_; @@ -154,6 +159,7 @@ class ParameterServer { std::condition_variable apply_grads_cv_; std::unique_ptr thread_; + std::map embedding_tables_; friend class ServerHandler; }; @@ -329,6 +335,7 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) { InitOptimInfoBuilders(); ps_->set_request_handle(*handler_); thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); + GetEmbeddingTableParamPtr(); return true; } @@ -464,6 +471,7 @@ template void ParameterServer::Finalize() { running_ = false; apply_grads_cv_.notify_one(); + SyncEmbeddingTables(); } template @@ -647,6 +655,53 @@ inline std::mutex &ParameterServer::mutex() { return mutex_; } +template +void ParameterServer::GetEmbeddingTableParamPtr() { + MS_EXCEPTION_IF_NULL(func_graph_); + auto cnodes = func_graph_->GetOrderedCnodes(); + Key count = 0; + for (auto cnode : cnodes) { + std::string cnode_name = AnfAlgo::GetCNodeName(cnode); + if (cnode_name == kEmbeddingLookupOpName) { + auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); + MS_EXCEPTION_IF_NULL(embedding_table); + MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; + embedding_tables_.insert(std::make_pair(count, embedding_table->cast())); + count++; + } + } +} + +template +void ParameterServer::SyncEmbeddingTables() { + for (auto embedding_table : embedding_tables_) { + Key key = embedding_table.first; + if (embedding_lookup_ops_.count(key) == 0) { + MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key; + } + auto lookup = embedding_lookup_ops_[key]; + const std::vector &input_shapes = lookup->input_sizes(); + std::vector new_tensor_shape(input_shapes.begin(), input_shapes.end()); + + tensor::TensorPtr new_tensor = std::make_shared(kNumberTypeFloat32, new_tensor_shape); + float *new_tensor_data_ptr = reinterpret_cast(new_tensor->data_c()); + size_t new_tensor_size = static_cast(new_tensor->data().nbytes()); + size_t embedding_table_size = weights_[key]->size() * sizeof(float); + if (new_tensor_size != embedding_table_size) { + MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size + << ", embedding_table size:" << embedding_table_size; + } + int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + auto paramter_tensor_ptr = embedding_table.second->default_param(); + MS_EXCEPTION_IF_NULL(paramter_tensor_ptr); + paramter_tensor_ptr->cast()->AssignValue(*new_tensor); + } +} + template void ParameterServer::Run(const FuncGraphPtr &func_graph) { ::ps::Start(0); @@ -657,7 +712,6 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { Init(func_graph); thread_->join(); ::ps::Finalize(0, true); - exit(1); } } // namespace ps } // namespace parallel diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 93360b00b6b..b827ffe3455 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -15,6 +15,7 @@ # limitations under the License. # ============================================================================ """Providing interface methods.""" +import os import types from collections import OrderedDict from functools import wraps @@ -468,7 +469,7 @@ class _Executor: return self._executor.has_compiled(phase) def __call__(self, obj, *args, phase='predict'): - if context.get_context("precompile_only"): + if context.get_context("precompile_only") or os.getenv("MS_ROLE") == "MS_PSERVER": return None return self.run(obj, *args, phase=phase) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 12aaa87638b..d4dd29bd5d2 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -16,6 +16,7 @@ from collections.abc import Iterable import os +import sys import math import numpy as np @@ -496,6 +497,8 @@ class Model: self._loss_scale_manager.update_loss_scale(overflow) list_callback.step_end(run_context) + if os.getenv("MS_ROLE") == "MS_PSERVER": + sys.exit(0) should_stop = should_stop or run_context.get_stop_requested() if should_stop: break diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 636633a4e30..2bb3c056c2d 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -214,8 +214,8 @@ class WideDeepModel(nn.Cell): self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) self.embedding_table = self.deep_embeddinglookup.embedding_table - self.wide_w.set_param_ps() - self.embedding_table.set_param_ps() + self.deep_embeddinglookup.embedding_table.set_param_ps() + self.wide_embeddinglookup.embedding_table.set_param_ps() else: self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE') self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE') diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py index 0a7034c71a6..a596e13c0f3 100644 --- a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py +++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py @@ -14,6 +14,7 @@ # ============================================================================ import os +import sys import argparse import numpy as np @@ -82,8 +83,12 @@ def do_sparse_embedding(ps=False): for _ in range(epoch): data = Tensor(np.random.randint(0, 15, (32, 3), np.int32)) label = Tensor(np.random.randint(0, 9, (32), np.int32)) - loss = train_network(data, label).asnumpy() - losses.append(loss) + if envs.get("MS_ROLE") == "MS_PSERVER": + train_network(data, label) + sys.exit() + else: + loss = train_network(data, label).asnumpy() + losses.append(loss) print(losses) return losses