diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h
index 5b18fe8860..158b890929 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h
@@ -31,8 +31,9 @@ class PServerKernel {
   ~PServerKernel() = default;
   PServerKernel(const PServerKernel &) = delete;
   PServerKernel &operator=(const PServerKernel &) = delete;
-  virtual void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
+  virtual void InitKernel(const CNodePtr &cnode,
+                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
   virtual void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
   virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                        const std::vector<AddressPtr> &outputs) = 0;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
index 4167c97674..222a980fc6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
@@ -23,7 +23,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyAdamPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> &var_shape = *(shape_vec[0]);
   std::vector<size_t> &m_shape = *(shape_vec[1]);
@@ -55,11 +55,9 @@ void SparseApplyAdamPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
   }
-  /*
-  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
-    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  */
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
index 0b35218abc..bd3e021a69 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h
@@ -30,7 +30,8 @@ class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerK
   SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyAdamPSKernel() override = default;
-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                const std::vector<AddressPtr> &outputs) override;
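Note: the `InitKernel` signature change above works together with the Python optimizer changes later in this diff. The worker attaches hyperparameters as attributes of the `Push` primitive, and the server-side kernel reads them back from the matching `Push` CNode. A minimal sketch of the worker-side half, using only the `P.Push`/`P.Pull` and `add_prim_attr` calls that appear in this diff (the standalone script around them is illustrative, not project code):

```python
# Sketch: how an attribute attached on the worker side becomes visible to the
# PS kernel. Requires MindSpore; "Adam" and the index list come from the diff.
from mindspore.ops import operations as P

ps_push = P.Push("Adam", [0, 1, 2])          # optimizer name + indices of optimizer inputs
ps_push.add_prim_attr("use_nesterov", True)  # stored as an attr on the Push CNode
ps_pull = P.Pull()
# On the server, SparseApplyAdamPSKernel::InitKernel(cnode, shapes) can then call
# AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov") against that node.
```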
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
index e350a9912a..afd676382f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
@@ -20,7 +20,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyFtrlPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> var_shape = *(shape_vec[0]);
   std::vector<size_t> accum_shape = *(shape_vec[1]);
@@ -46,10 +46,22 @@ void SparseApplyFtrlPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
-  lr_ = 0.01;
-  l1_ = 1e-8;
-  l2_ = 1e-8;
-  lr_power_ = -0.5;
+  lr_ = AnfAlgo::GetNodeAttr<float>(cnode, "lr");
+  if (lr_ <= 0) {
+    MS_LOG(EXCEPTION) << "lr should be a positive scalar";
+  }
+  l1_ = AnfAlgo::GetNodeAttr<float>(cnode, "l1");
+  if (l1_ < 0) {
+    MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
+  }
+  l2_ = AnfAlgo::GetNodeAttr<float>(cnode, "l2");
+  if (l2_ < 0) {
+    MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
+  }
+  lr_power_ = AnfAlgo::GetNodeAttr<float>(cnode, "lr_power");
+  if (lr_power_ > 0) {
+    MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
+  }
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
index 033c0b9ac2..3a5dfc738e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
@@ -30,7 +30,8 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerK
   SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyFtrlPSKernel() override = default;
-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
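Note: the FTRL kernel now takes `lr`, `l1`, `l2` and `lr_power` from the Push node instead of hard-coding them, and rejects out-of-range values. A hedged Python mirror of those range checks (the real validation is the C++ above; `check_ftrl_attrs` is an illustrative name, not a MindSpore API):

```python
def check_ftrl_attrs(lr, l1, l2, lr_power):
    """Restate the ranges enforced by SparseApplyFtrlPSKernel::InitKernel (sketch)."""
    if lr <= 0:
        raise ValueError("lr should be a positive scalar")
    if l1 < 0:
        raise ValueError("l1 should be a non-negative scalar")
    if l2 < 0:
        raise ValueError("l2 should be a non-negative scalar")
    if lr_power > 0:
        raise ValueError("lr_power should be a non-positive scalar")


# The defaults used by nn.FTRL below pass the checks.
check_ftrl_attrs(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
```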
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
index 50b7c0a157..03949b3685 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
@@ -23,7 +23,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyLazyAdamPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> &var_shape = *(shape_vec[0]);
   std::vector<size_t> &m_shape = *(shape_vec[1]);
@@ -55,11 +55,9 @@ void SparseApplyLazyAdamPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
   }
-  /*
-  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
-    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  */
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h
index 9a8e3d0475..595f2ab6a3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h
@@ -30,7 +30,8 @@ class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public
   SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyLazyAdamPSKernel() override = default;
-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                const std::vector<AddressPtr> &outputs) override;
diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h
index dc79441d20..a072a65001 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/common.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -57,15 +57,20 @@ constexpr char kMomentum[] = "momentum";
 constexpr char kApplyMomentum[] = "ApplyMomentum";
 constexpr char kSparseAdam[] = "Adam";
 constexpr char kSparseFtrl[] = "Ftrl";
+constexpr char kApplyMomentumOp[] = "Momentum";
+constexpr char kSparseAdamOp[] = "Adam";
+constexpr char kSparseFtrlOp[] = "FTRL";
 
 constexpr int kInitWeightsCmd = 10;
 constexpr int kInitWeightToOptimIdCmd = 11;
 constexpr int kInitOptimInputsShapeCmd = 12;
+constexpr int kInitKeyToPushNodeIdCmd = 13;
 constexpr int kInitEmbeddingsCmd = 20;
 constexpr int kEmbeddingLookupCmd = 30;
 constexpr int kFinalizeCmd = 40;
 
 constexpr size_t kInvalidKey = UINT64_MAX;
+constexpr int kInvalidID = -1;
 
 using Key = ::ps::Key;
 using Keys = ::ps::SArray<Key>;
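Note: common.h now distinguishes the optimizer names used for push registration (`kApplyMomentum`/`kSparseAdam`/`kSparseFtrl`) from the node names used to locate the optimizer's Push CNode in the graph (`kApplyMomentumOp`/`kSparseAdamOp`/`kSparseFtrlOp`); the latter appear to match the names that show up in `fullname_with_scope()`. A small Python sketch of the two tables and the lookup added in util.cc (the dict and function names here are illustrative):

```python
# Sketch: the two lookup tables that common.h/util.cc now keep side by side.
ID_TO_OPTIMIZERS = {0: "ApplyMomentum", 1: "Adam", 2: "Ftrl"}   # kApplyMomentum / kSparseAdam / kSparseFtrl
ID_TO_OPTIMIZER_NODES = {0: "Momentum", 1: "Adam", 2: "FTRL"}   # kApplyMomentumOp / kSparseAdamOp / kSparseFtrlOp


def optimizer_node_name(optim_id):
    """Python rendering of Util::optimizer_node_name (sketch)."""
    return ID_TO_OPTIMIZER_NODES.get(optim_id, "")
```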
diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
old mode 100755
new mode 100644
index c4e5e9121b..64983458df
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <list>
 #include "ir/func_graph.h"
 #include "backend/session/session_basic.h"
 #include "backend/session/anf_runtime_algorithm.h"
@@ -116,6 +117,7 @@ class ParameterServer {
   bool ReadyForUpdateWeights();
   bool ReadyForAccumGrads();
   void ResetGradAccumCount();
+  const CNodePtr GetCNode(const std::string &name) const;
 
   size_t pserver_num_;
   size_t worker_num_;
@@ -132,6 +134,7 @@ class ParameterServer {
   std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
   std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
   std::unordered_map<Key, std::string> weight_key_to_optims_;
+  std::unordered_map<Key, std::string> weight_key_to_optim_op_;
   std::unordered_map<Key, WeightPtr> weights_;
   std::unordered_map<Key, WeightPtr> grads_;
   std::unordered_map<Key, size_t> grads_accum_counter_;
@@ -277,7 +280,6 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
   handler_->Init();
 
   InitOptimInfoBuilders();
-  ps_->set_request_handle(*handler_);
   thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
 
   return true;
@@ -299,6 +301,7 @@ void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_
     return;
   }
   weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
+  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
 }
 
 template <typename T>
@@ -321,27 +324,42 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
   }
   if (weight_key_to_optims_.count(key) > 0) {
     const std::string &optim_name = weight_key_to_optims_[key];
+    const std::string &optim_op_name = weight_key_to_optim_op_[key];
     if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
+      const CNodePtr cnode = GetCNode(optim_op_name);
+      MS_EXCEPTION_IF_NULL(cnode);
       if (optim_name == kSparseAdam) {
         std::shared_ptr<PServerKernel> optimizer = std::make_shared(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kApplyMomentum) {
         std::shared_ptr<PServerKernel> optimizer = std::make_shared(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kSparseFtrl) {
        std::shared_ptr<PServerKernel> optimizer = std::make_shared(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       }
     }
   }
 }
 
+template <typename T>
+const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
+  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
+  for (CNodePtr cnode : cnodes) {
+    std::string fullname = cnode->fullname_with_scope();
+    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
+      return cnode;
+    }
+  }
+  return nullptr;
+}
+
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc
index fc63e88901..f7ecbb8a52 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/util.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc
@@ -33,6 +33,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizers{
   {1, kSparseAdam},
   {2, kSparseFtrl},
 };
+
+std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
+  {0, kApplyMomentumOp},
+  {1, kSparseAdamOp},
+  {2, kSparseFtrlOp},
+};
+
 bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
 
 bool Util::IsRoleOfWorker() {
@@ -112,6 +119,13 @@ std::string Util::optimizer_name(int id) {
   return "";
 }
 
+std::string Util::optimizer_node_name(int id) {
+  if (id_to_optimizer_nodes.count(id) > 0) {
+    return id_to_optimizer_nodes[id];
+  }
+  return "";
+}
+
 bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
 
 int Util::LocalShard(int first_dim, int rank_id, int server_num) {
diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h
index 2649f452e1..c56ab8264e 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/util.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/util.h
@@ -34,12 +34,14 @@ class Util {
   static void SetInternalEnvVar();
   static int optimizer_id(std::string name);
   static std::string optimizer_name(int id);
+  static std::string optimizer_node_name(int id);
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);
 
  private:
   static std::unordered_map<std::string, int> optimizer_to_ids;
   static std::unordered_map<int, std::string> id_to_optimizers;
+  static std::unordered_map<int, std::string> id_to_optimizer_nodes;
 };
 } // namespace ps
 } // namespace parallel
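Note: `GetCNode` locates the optimizer's Push node purely by substring matching on `fullname_with_scope()`: the name must contain both the optimizer node name (from `Util::optimizer_node_name`) and `"Push"`. A minimal Python sketch of that lookup; the node names in the example are hypothetical, chosen only to look like typical `fullname_with_scope()` strings:

```python
# Sketch of the matching rule used by ParameterServer::GetCNode.
def find_push_node(fullnames, op_name):
    """Return the first node name containing both op_name and "Push", else None."""
    for fullname in fullnames:
        if op_name in fullname and "Push" in fullname:
            return fullname
    return None


names = ["Default/optimizer-FTRL/Push-op152",   # hypothetical example names
         "Default/optimizer-FTRL/Pull-op153"]
assert find_push_node(names, "FTRL") == "Default/optimizer-FTRL/Push-op152"
```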
diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py
index 1bf49d9ec7..e31e6345d4 100644
--- a/mindspore/nn/optim/__init__.py
+++ b/mindspore/nn/optim/__init__.py
@@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients.
 """
 from .optimizer import Optimizer
 from .momentum import Momentum
-from .adam import Adam, PSAdam, AdamWeightDecay
+from .adam import Adam, AdamWeightDecay
 from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
-from .ftrl import FTRL, PSFTRL
+from .ftrl import FTRL
 from .rmsprop import RMSProp
 from .proximal_ada_grad import ProximalAdagrad
 from .lazyadam import LazyAdam
 
-__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
-           'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
+__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
+           'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index b823bf69f3..bd0d18347a 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -27,7 +27,6 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 
 _adam_opt = C.MultitypeFuncGraph("adam_opt")
-_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt")
 
 
 @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@@ -85,77 +84,42 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     return gradient
 
 
-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "IndexedSlices",
-                    "Tensor", "Tensor", "Tensor", "Bool")
-def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                         moment1, moment2, ps_parameter):
+@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool")
+def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
+                         gradient, params, moment1, moment2, ps_parameter):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     indices = gradient.indices()
     values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Adam", [0, 1, 2])
         shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                   op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                   op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
-        success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
-                                                       eps, values, indices), shapes), params))
+        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
+                                               eps, values, indices), shapes), params))
     else:
         success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                                eps, values, indices))
     return success
 
 
Tensor.""" success = True if ps_parameter: op_shape = P.Shape() - _ps_pull = P.Pull() - _ps_push = P.Push("Adam", [0, 1, 2]) - success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), - (op_shape(params), op_shape(moment1), op_shape(moment2))), - params)) + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), + (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) else: success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)) return success - -@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor") -def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): - """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse.""" - success = True - op_shape = P.Shape() - values = gradient.values() - indices = gradient.indices() - shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), - op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), - op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) - success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, - eps, values, indices), shapes), params)) - return success - - -@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): - """Apply adam optimizer by push and pull to the weight parameter using Tensor.""" - success = True - op_shape = P.Shape() - success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), - (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) - return success - - def _check_param_value(beta1, beta2, eps, prim_name): """Check the type of inputs.""" validator.check_value_type("beta1", beta1, [float], prim_name) @@ -285,6 +249,10 @@ class Adam(Optimizer): self.opt = P.Adam(use_locking, use_nesterov) self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov) + self._ps_pull = P.Pull() + self._ps_push = P.Push("Adam", [0, 1, 2]) + self._ps_push.add_prim_attr("use_nesterov", use_nesterov) + def construct(self, gradients): params = self.parameters moment1 = self.moment1 @@ -298,63 +266,16 @@ class Adam(Optimizer): beta2_power = self.beta2_power * self.beta2 self.beta2_power = beta2_power if self.is_group_lr: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps), + success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps), lr, gradients, params, moment1, moment2, self.ps_parameters) else: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps, lr), + success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), gradients, params, moment1, moment2, self.ps_parameters) return success -class PSAdam(Optimizer): - '''The same usage as Adam 
-class PSAdam(Optimizer):
-    '''The same usage as Adam optimizer except the parameters are set PS mode.'''
-    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
-                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
-        super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
-        _check_param_value(beta1, beta2, eps, self.cls_name)
-        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
-        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
-
-        self.beta1 = Tensor(beta1, mstype.float32)
-        self.beta2 = Tensor(beta2, mstype.float32)
-        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
-        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
-        self.eps = Tensor(eps, mstype.float32)
-
-        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
-        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
-
-        self.hyper_map = C.HyperMap()
-        self.push = P.Push("Adam", [0, 1, 2])
-        self.push.add_prim_attr("primitive_target", "CPU")
-        self.pull = P.Pull()
-        self.pull.add_prim_attr("primitive_target", "CPU")
-
-    def construct(self, gradients):
-        params = self.parameters
-        moment1 = self.moment1
-        moment2 = self.moment2
-        gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
-        lr = self.get_lr()
-
-        beta1_power = self.beta1_power * self.beta1
-        self.beta1_power = beta1_power
-        beta2_power = self.beta2_power * self.beta2
-        self.beta2_power = beta2_power
-        if self.is_group_lr:
-            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps),
-                                lr, gradients, params, moment1, moment2)
-        else:
-            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps, lr),
-                                gradients, params, moment1, moment2)
-        return success
-
-
 class AdamWeightDecay(Optimizer):
     """
     Implements Adam algorithm weight decay fix.
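Note: with `PSAdam` removed, the regular `Adam` class is used for parameter-server training as well; it now builds its own `Push`/`Pull` pair and forwards `use_nesterov` to the server through the Push node attribute. A hedged usage sketch, where `nn.Dense` stands in for a real model and enabling PS mode for the parameters (roles, environment setup) is assumed to happen elsewhere:

```python
# Sketch: the post-change way to get Adam onto the parameter server path.
import mindspore.nn as nn

net = nn.Dense(16, 10)  # placeholder model
opt = nn.Adam(net.trainable_params(), learning_rate=1e-3, use_nesterov=True)
# use_nesterov=True is attached to the internal Push("Adam", [0, 1, 2]) primitive,
# so the server-side sparse Adam kernels can pick it up from the CNode.
```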
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index 098c382da2..e2be57e820 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -21,68 +21,40 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, _apply_decay, _grad_scale
 
 _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
-_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
 
 
-@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "IndexedSlices", "Tensor",
-                    "Tensor", "Bool")
-def _tensor_run_opt_with_sparse(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment,
-                                ps_parameter):
+@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
+                    "IndexedSlices", "Tensor", "Tensor", "Bool")
+def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
+                                gradient, weight, moment, ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
     indices = gradient.indices()
     values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Ftrl", [0, 1, 2])
         shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
-        success = F.depend(success, _ps_pull(_ps_push((values, indices), shapes), weight))
+        success = F.depend(success, pull(push((values, indices), shapes), weight))
     else:
         success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
     return success
 
 
-@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Bool")
-def _tensor_run_opt(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter):
+@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
+                    "Tensor", "Tensor", "Tensor", "Bool")
+def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
+                    gradient, weight, moment, ps_parameter):
     """Apply ftrl optimizer to the weight parameter."""
     success = True
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Ftrl", [0, 1, 2])
-        success = F.depend(success, _ps_pull(_ps_push((gradient, learning_rate, l1, l2, lr_power),
-                                                      (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
+        success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
+                                              (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
     else:
         success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success
 
-
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices",
-                              "Tensor", "Tensor")
-def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
-                                          weight, moment):
-    success = True
-    op_shape = P.Shape()
-    values = gradient.values()
-    indices = gradient.indices()
-    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
-    success = F.depend(success, pull(push((values, indices), shapes), weight))
-    return success
-
-
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor",
-                              "Tensor", "Tensor")
-def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
-                                              weight, moment):
-    success = True
-    op_shape = P.Shape()
-    success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
-                                          (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
-    return success
-
-
 def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):
     """Check param."""
     validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
@@ -188,6 +160,12 @@ class FTRL(Optimizer):
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
+        self._ps_pull = P.Pull()
+        self._ps_push = P.Push("Ftrl", [0, 1, 2])
+        self._ps_push.add_prim_attr("lr", learning_rate)
+        self._ps_push.add_prim_attr("l1", l1)
+        self._ps_push.add_prim_attr("l2", l2)
+        self._ps_push.add_prim_attr("lr_power", lr_power)
 
     def construct(self, grads):
         params = self.parameters
@@ -197,41 +175,7 @@ class FTRL(Optimizer):
         grads = self.scale_grad(grads)
         lr = self.get_lr()
 
-        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.l1, self.l2, self.lr_power, lr),
+        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                      self.l1, self.l2, self.lr_power, lr),
                             linear, grads, params, moments, self.ps_parameters)
         return success
-
-
-class PSFTRL(Optimizer):
-    def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
-                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
-        super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
-        if self.is_group:
-            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
-        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
-        self.linear = self.parameters.clone(prefix="linear", init='zeros')
-        self.l1 = l1
-        self.l2 = l2
-        self.lr_power = lr_power
-        self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
-
-        self.hyper_map = C.HyperMap()
-        self.push = P.Push("Ftrl", [0, 1, 2])
-        self.push.add_prim_attr("primitive_target", "CPU")
-        self.pull = P.Pull()
-        self.pull.add_prim_attr("primitive_target", "CPU")
-
-    def construct(self, grads):
-        params = self.parameters
-        moments = self.moments
-        linear = self.linear
-        lr = self.learning_rate
-        if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
-
-        grads = self.scale_grad(grads)
-        success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2, self.lr_power),
-                            linear, grads, params, moments)
-        return success
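Note: the same pattern applies to FTRL. `PSFTRL` is removed, and `FTRL` attaches `lr`, `l1`, `l2` and `lr_power` to its internal Push primitive, which are exactly the attributes the server-side `SparseApplyFtrlPSKernel::InitKernel` now reads and validates. A hedged usage sketch, with `nn.Dense` as a placeholder model and the parameter-server role/parameter setup assumed to be configured elsewhere:

```python
# Sketch: the post-change way to get FTRL onto the parameter server path.
import mindspore.nn as nn

net = nn.Dense(16, 10)  # placeholder model
opt = nn.FTRL(net.trainable_params(), learning_rate=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
# These hyperparameters are attached to Push("Ftrl", [0, 1, 2]) as "lr", "l1",
# "l2" and "lr_power", and must satisfy the server-side range checks above.
```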