From 61551b85d801941513918121be6468ae9696a709 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Thu, 16 Jul 2020 22:42:18 +0800 Subject: [PATCH] incremental feature for ps --- .../cpu/ps/sparse_apply_adam_ps_kernel.cc | 2 + .../cpu/ps/sparse_apply_ftrl_ps_kernel.cc | 2 + mindspore/ccsrc/frontend/parallel/ps/common.h | 5 +- .../frontend/parallel/ps/optimizer_info.cc | 2 + .../frontend/parallel/ps/optimizer_info.h | 1 + .../parallel/ps/optimizer_info_builder.cc | 1 + .../frontend/parallel/ps/parameter_server.h | 117 +++++++++--------- mindspore/ccsrc/frontend/parallel/ps/worker.h | 12 +- .../ccsrc/frontend/parallel/ps/worker_proxy.h | 47 ++----- mindspore/ccsrc/utils/utils.h | 1 - mindspore/common/parameter.py | 14 ++- mindspore/communication/_comm_helper.py | 2 + mindspore/nn/cell.py | 4 +- mindspore/nn/optim/adam.py | 6 +- mindspore/nn/optim/momentum.py | 2 +- 15 files changed, 102 insertions(+), 116 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc index fa91f459472..4167c976747 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -62,6 +62,8 @@ void SparseApplyAdamPSKernel::InitKernel( */ workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc index 93cd38c11b5..e350a9912af 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -52,6 +52,8 @@ void SparseApplyFtrlPSKernel::InitKernel( lr_power_ = -0.5; workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); } void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr>>> &shapes) { diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h index bcd9a3a65ca..3921df2bd9c 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/common.h +++ b/mindspore/ccsrc/frontend/parallel/ps/common.h @@ -72,13 +72,10 @@ using Values = ::ps::SArray; using ValuesPtr = std::shared_ptr; using Weight = ::ps::SArray; using Grad = ::ps::SArray; -using LookupIds = ::ps::SArray; +using LookupIds = ::ps::SArray; using Lengths = ::ps::SArray; using WeightPtr = std::shared_ptr; using GradPtr = std::shared_ptr; -// using EmbeddingTable = std::unordered_map; -// using EmbeddingTable = ::ps::SArray; -// using EmbeddingTablePtr = std::shared_ptr; using InputsShape = std::vector>>; using InputsShapePtr = std::shared_ptr>>>; } // namespace ps diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index cbfa5829837..6ec68b84c9e 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -57,6 +57,8 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { } } +void DenseOptimInfo::Reset() { memset_s(gradient()->addr, gradient()->size, 0x00, gradient()->size); } + void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { // Append grad data to the end float *accum_grad_data = reinterpret_cast(gradient()->addr); diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index ba9cb3f7d29..36d2b58a0e3 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -58,6 +58,7 @@ class DenseOptimInfo : public OptimizerInfo { ~DenseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; + void Reset() override; }; class SparseOptimInfo : public OptimizerInfo { diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc index 7b6686ea869..b87b0016961 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -58,6 +58,7 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co AddressPtr accumulate = std::make_shared(); accumulate->addr = new float[weight->size()]; accumulate->size = weight->size() * sizeof(float); + memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size); AddressPtr learning_rate = std::make_shared(); learning_rate->addr = copy_data_ptr; learning_rate->size = lens[0] * sizeof(float); diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 182ff8d3606..7105b887db0 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -30,7 +30,6 @@ #include #include "ir/func_graph.h" #include "backend/session/session_basic.h" -#include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/session_factory.h" #include "frontend/parallel/ps/common.h" @@ -70,24 +69,32 @@ class ParameterServer { ps_(new ::ps::KVServer(0)), handler_(nullptr), func_graph_(nullptr), - kernel_graph_(nullptr), sess_(nullptr), thread_(nullptr) {} ~ParameterServer() = default; ParameterServer(const ParameterServer &) = delete; ParameterServer &operator=(const ParameterServer &) = delete; - struct ServerHandler { + class ServerHandler { + public: explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + void Init(); void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); - void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); + + private: + void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitWeights(const ::ps::KVPairs &req_data); - void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); - void HandleInitInputsShape(const ::ps::KVPairs &req_data); - void HandleInitEmbeddings(const ::ps::KVPairs &req_data); + void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res); + void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + ParameterServer *ps_; + typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res); + std::unordered_map handlers_; }; bool Init(const FuncGraphPtr &func_graph); @@ -103,7 +110,6 @@ class ParameterServer { WeightPtr weight(const Key &key); void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); int SumOfShapes(const std::vector &shapes) const; - size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); bool ReadyForUpdateWeights(); bool ReadyForAccumGrads(); void ResetGradAccumCount(); @@ -115,7 +121,6 @@ class ParameterServer { std::unique_ptr<::ps::KVServer> ps_; std::unique_ptr handler_; FuncGraphPtr func_graph_; - std::shared_ptr kernel_graph_; std::shared_ptr sess_; std::unordered_map> optimizers_; @@ -126,12 +131,7 @@ class ParameterServer { std::unordered_map weights_; std::unordered_map grads_; std::unordered_map grads_accum_counter_; - // std::unordered_map embeddings_; std::unordered_map> embedding_lookup_ops_; - std::unordered_map embedding_row_lens_; - - T learning_rate_; - T momentum_; std::mutex mutex_; std::condition_variable apply_grads_cv_; @@ -139,7 +139,7 @@ class ParameterServer { std::unique_ptr thread_; - friend struct ServerHandler; + friend class ServerHandler; }; class FuncGraph; @@ -147,33 +147,29 @@ template void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server) { ::ps::KVPairs res; - if (req_meta.cmd == kInitWeightsCmd) { - MS_LOG(ERROR) << "handle init weights cmd" << std::endl; - HandleInitWeights(req_data); - } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { - MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; - HandleInitWeightToOptimId(req_data); - } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { - MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; - HandleInitInputsShape(req_data); - } else if (req_meta.cmd == kInitEmbeddingsCmd) { - MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; - HandleInitEmbeddings(req_data); - } else if (req_meta.cmd == kEmbeddingLookupCmd) { - MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; - HandleEmbeddingLookup(req_meta, req_data, &res); + if (handlers_.count(req_meta.cmd) > 0) { + auto &handler_ptr = handlers_[req_meta.cmd]; + (this->*handler_ptr)(req_meta, req_data, &res); } else if (req_meta.push) { - MS_LOG(ERROR) << "handle push req cmd" << std::endl; - HandlePushReq(req_meta, req_data); + HandlePushReq(req_meta, req_data, &res); } else { - MS_LOG(ERROR) << "handle pull req cmd" << std::endl; HandlePullReq(req_meta, req_data, &res); } server->Response(req_meta, res); } template -void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::Init() { + handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights; + handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId; + handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape; + handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings; + handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; +} + +template +void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); } @@ -186,7 +182,8 @@ void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me } template -void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { size_t key_num = req_data.keys.size(); T *data_ptr = req_data.vals.data(); size_t pos = 0; @@ -205,7 +202,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs } template -void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { size_t key_num = req_data.keys.size(); for (size_t i = 0; i < key_num; i++) { Key key = req_data.keys[i]; @@ -215,12 +214,14 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV } template -void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); } template -void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { std::shared_ptr>>> shapes = std::make_shared>>>(); std::shared_ptr> input_shape = std::make_shared>(); @@ -249,10 +250,10 @@ template void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { const Key &key = req_data.keys[0]; - for (size_t i = 0; i < req_data.vals.size(); i++) { - res->keys.push_back(req_data.vals[i]); + for (size_t i = 0; i < req_data.keys.size(); i++) { + res->keys.push_back(req_data.keys[i]); } - ps_->DoEmbeddingLookup(key, req_data.vals, res); + ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); } template @@ -268,6 +269,7 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; rank_id_ = ::ps::MyRank(); handler_.reset(new ServerHandler(this)); + handler_->Init(); InitOptimInfoBuilders(); @@ -364,7 +366,13 @@ void ParameterServer::InitEmbeddingTable( for (auto shape : input_shapes) { total_dims *= shape; } - WeightPtr embedding = std::make_shared(total_dims, 0.01); + + WeightPtr embedding = std::make_shared(total_dims, 0); + std::default_random_engine engine; + std::normal_distribution random(0, 0.01); + for (size_t i = 0; i < total_dims; i++) { + (*embedding)[i] = random(engine); + } weights_[key] = embedding; grads_accum_counter_[key] = 0; @@ -480,8 +488,13 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, inputs.push_back(indices); embedding_table->addr = table_ptr->data(); embedding_table->size = table_ptr->size() * sizeof(T); - indices->addr = lookup_ids.data(); - indices->size = lookup_ids.size() * sizeof(T); + + std::unique_ptr tmp_ids(new int[lookup_ids.size()]); + for (size_t i = 0; i < lookup_ids.size(); i++) { + tmp_ids[i] = static_cast(lookup_ids[i]); + } + indices->addr = tmp_ids.get(); + indices->size = lookup_ids.size() * sizeof(int); std::vector workspaces; std::vector outputs; @@ -506,20 +519,6 @@ int ParameterServer::SumOfShapes(const std::vector &shapes) const { return sum; } -template -size_t ParameterServer::PreComputeCapacity(const Keys &keys, const Lengths &lens) { - size_t capacity = 0; - for (size_t i = 0; i < keys.size(); i++) { - Key key = keys[i]; - if (embedding_row_lens_.count(key) > 0) { - capacity += embedding_row_lens_[key] * lens[i]; - } else { - MS_LOG(ERROR) << "Invalid embedding lookup id " << key; - } - } - return capacity; -} - template inline bool ParameterServer::ReadyForUpdateWeights() { return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index 836377dd10e..411179ad900 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -155,9 +155,9 @@ void Worker::InitPSOptimInputShapes(const size_t key) { } } } - MS_LOG(ERROR) << "keys:" << keys; - MS_LOG(ERROR) << "shape_len:" << shape_len; - MS_LOG(ERROR) << "all_shape:" << all_shape; + MS_LOG(INFO) << "keys:" << keys; + MS_LOG(INFO) << "shape_len:" << shape_len; + MS_LOG(INFO) << "all_shape:" << all_shape; if (!init_keys_[key]) { init_keys_[key] = true; } @@ -191,7 +191,7 @@ size_t Worker::GetParamKey(const std::string ¶m_name) { size_t key = kInvalidKey; if (param_to_key_.find(param_name) != param_to_key_.end()) { key = param_to_key_[param_name]; - MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key; + MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key; } return key; } @@ -251,6 +251,10 @@ void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_d template void Worker::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + bool has_init = IsKeyInit(key); + if (has_init) { + return; + } kv_worker_->AddEmbeddingTable(key, row_count); } diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index 055c87670ed..f437f721831 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -156,30 +156,8 @@ int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps: auto &kvs = lookup_results_[ts]; mutex_.unlock(); - size_t total_len = 0; - std::unordered_map>> id_addr_map; - for (const auto &s : kvs) { - int offset = 0; - int len = s.vals.size() / s.keys.size(); - for (size_t i = 0; i < s.keys.size(); i++) { - const Key &key = s.keys[i]; - T *addr = s.vals.data() + offset; - offset += len; - total_len += len; - id_addr_map[key] = std::make_shared>(std::make_pair(addr, len)); - } - } - - T *result_addr = lookup_result->data(); - int offset = 0; - for (size_t i = 0; i < lookup_ids.size(); i++) { - auto &pair = id_addr_map[static_cast(lookup_ids[i])]; - auto ret = memcpy_s(result_addr + offset, pair->second, pair->first, pair->second); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - } - offset += pair->second; - } + auto &s = kvs[0]; + *lookup_result = s.vals; mutex_.lock(); lookup_results_.erase(ts); @@ -201,25 +179,16 @@ void WorkerProxy::LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, sliced->resize(ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { - const ::ps::Range &range = ranges[i]; - const auto &begin = range.begin(); - const auto &end = range.end(); - std::unordered_set unique_ids; auto &kvs = sliced->at(i).second; - for (size_t j = 0; j < id_size; j++) { - auto lookup_id = static_cast(lookup_ids[j]); - if (lookup_id >= begin && lookup_id <= end) { - unique_ids.insert(lookup_id); - } - } - for (const auto &lookup_id : unique_ids) { - kvs.vals.push_back(lookup_id); - } kvs.keys.push_back(key); - kvs.lens.push_back(kvs.vals.size()); + kvs.vals.push_back(0.0f); + for (size_t j = 0; j < id_size; j++) { + kvs.keys.push_back(lookup_ids[j]); + kvs.vals.push_back(0.0f); + } - if (kvs.vals.size() == 0) { + if (kvs.keys.size() <= 1) { sliced->at(i).first = false; } else { sliced->at(i).first = true; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index bebf51ed9dc..1ce981e2e3a 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -318,7 +318,6 @@ const std::set kOptOperatorSet = { kApplyProximalAdagradOpName, kApplyProximalGradientDescentOpName, kApplyRMSPropOpName, - kPushOpName, kPullOpName, }; diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 9405e7b2602..4f45f73e88b 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -61,6 +61,7 @@ class Parameter: self._is_init = False self._sliced = False self.is_param_ps = False + self.init_in_server = False if context.get_context("mode") == context.PYNATIVE_MODE: self.init_data() @@ -71,8 +72,9 @@ class Parameter: def __parameter__(self): """For parse check.""" - def set_param_ps(self): + def set_param_ps(self, init_in_server=False): self.is_param_ps = True + self.init_in_server = init_in_server @property def name(self): @@ -251,9 +253,15 @@ class Parameter: raise ValueError("The length of layout must be larger than 3! layout is {}." .format(layout)) slice_index = int(_get_slice_index(layout[0], layout[1])) - self.default_input = self.init_mode.to_tensor(slice_index, layout[2]) + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + self.default_input = self.init_mode.to_tensor(0, [1]) + else: + self.default_input = self.init_mode.to_tensor(slice_index, layout[2]) else: - self.default_input = self.init_mode.to_tensor() + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + self.default_input = self.init_mode.to_tensor(0, [1]) + else: + self.default_input = self.init_mode.to_tensor() self.init_mode = None if set_sliced: diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 5e1f7d06e72..1723fe9c979 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -113,6 +113,8 @@ def check_parameter_available(func): Wrapper. If not available, raise Error. """ def wrapper(*args, **kargs): + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + return func(*args, **kargs) group = None if "group" in kargs.keys(): group = kargs.get("group") diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 3eec96f0b5f..2209a3f9695 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -831,7 +831,7 @@ class Cell: self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self.enable_hook = True - def set_param_ps(self, recurse=True): + def set_param_ps(self, recurse=True, init_in_server=False): """ Set whether the trainable parameter is updated by parameter server. @@ -843,7 +843,7 @@ class Cell: """ params = self.trainable_params(recurse) for param in params: - param.set_param_ps() + param.set_param_ps(init_in_server) class GraphKernel(Cell): """ diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index e302cc599ac..ad7096a93f7 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -85,7 +85,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d return gradient -@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices", +@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool") def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): @@ -108,7 +108,7 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2 return success -@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", +@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): @@ -276,7 +276,7 @@ class Adam(Optimizer): self.beta2 = Tensor(beta2, mstype.float32) self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") - self.eps = eps + self.eps = Tensor(eps, mstype.float32) self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 61c06591944..a1730e7c679 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, if ps_parameter: op_shape = P.Shape() _ps_pull = P.Pull() - _ps_push = P.Push("Momentum", []) + _ps_push = P.Push("ApplyMomentum", []) shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum)) success = F.depend(success, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight)) else: