From a3e6009b53345b34ed396a36fb098c07b4e967aa Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Wed, 23 Jun 2021 10:24:40 +0800 Subject: [PATCH] Add reverse option for HyperMap and Map, and enable reverse option for optimizers. --- .../frontend/operator/composite/composite.cc | 128 ++++++++++++------ .../frontend/operator/composite/composite.h | 7 +- .../ccsrc/frontend/operator/composite/map.cc | 112 ++++++++++----- .../ccsrc/frontend/operator/composite/map.h | 25 ++-- .../pipeline/jit/static_analysis/evaluator.cc | 33 ++--- .../pipeline/jit/static_analysis/evaluator.h | 17 ++- .../jit/static_analysis/stack_frame.cc | 18 ++- .../jit/static_analysis/stack_frame.h | 16 +-- mindspore/nn/optim/ada_grad.py | 9 +- mindspore/nn/optim/adam.py | 51 ++++--- mindspore/nn/optim/ftrl.py | 7 +- mindspore/nn/optim/lamb.py | 25 ++-- mindspore/nn/optim/lars.py | 1 - mindspore/nn/optim/lazyadam.py | 23 ++-- mindspore/nn/optim/momentum.py | 9 +- mindspore/nn/optim/optimizer.py | 3 + mindspore/nn/optim/proximal_ada_grad.py | 10 +- mindspore/nn/optim/rmsprop.py | 22 +-- mindspore/nn/optim/sgd.py | 7 +- mindspore/nn/optim/thor.py | 2 - mindspore/ops/composite/base.py | 18 +-- 21 files changed, 326 insertions(+), 217 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 9dd07034020..2d657cc2e41 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -123,16 +123,21 @@ void HyperMap::Init() { {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -HyperMap::HyperMap(const std::shared_ptr &fn_leaf) +HyperMap::HyperMap(bool reverse, const std::shared_ptr &fn_leaf) : MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), + reverse_(reverse), broadcast_(false), nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { Init(); } HyperMap::HyperMap(const HyperMap &h) - : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + : MetaFuncGraph("hyper_map"), + fn_leaf_(h.fn_leaf_), + reverse_(h.reverse_), + broadcast_(h.broadcast_), + nonleaf_(h.nonleaf_) { Init(); } @@ -156,12 +161,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraph MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); + size_t size = type->elements().size(); + size_t num = 0; + bool is_not_same = + std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair &item) { + num++; + auto lhs = std::static_pointer_cast(item.second); + if (lhs == nullptr) { + MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got " + << item.second->ToString(); + } + if (lhs->elements().size() != size) { + MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got " + << lhs->elements().size(); + return true; + } + return false; + }); if (is_not_same) { MS_LOG(EXCEPTION) << "List in HyperMap should have same length"; } @@ -169,24 +185,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraph // cannot use shared_from_base() also known as this, as it will make a reference cycle on // hypermap and graph generated, it 
will cause memory leak. auto fn_rec = NewValueNode(std::make_shared(*this)); + constexpr size_t kPrimHoldLen = 1; std::vector inputs; + inputs.reserve(size + kPrimHoldLen); inputs.push_back(NewValueNode(prim::kPrimMakeList)); - for (int64_t i = 0; i < SizeToLong(size); ++i) { + for (size_t i = 0; i < size; i++) { + MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_; std::vector inputs2; inputs2.push_back(fn_rec); if (fn_arg != nullptr) { inputs2.push_back(fn_arg); } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); + size_t pos = (reverse_ ? (size - 1 - i) : i); + (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), + [&func_graph, pos](const std::pair &item) { + return func_graph->NewCNodeInOrder( + {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))}); + }); auto call_node = func_graph->NewCNodeInOrder(inputs2); - inputs.push_back(call_node); + if (reverse_) { + inputs.insert(inputs.begin() + 1, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -196,12 +219,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); + size_t size = type->elements().size(); + size_t num = 0; + bool is_not_same = + std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair &item) { + num++; + auto lhs = std::static_pointer_cast(item.second); + if (lhs == nullptr) { + MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got " + << item.second->ToString(); + } + if (lhs->elements().size() != size) { + MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got " + << lhs->elements().size(); + return true; + } + return false; + }); if (is_not_same) { MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length"; } @@ -209,23 +243,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap // cannot use shared_from_base() also known as this, as it will make a reference cycle on // hypermap and graph generated, it will cause memory leak. auto fn_rec = NewValueNode(std::make_shared(*this)); + constexpr size_t kPrimHoldLen = 1; std::vector inputs; + inputs.reserve(size + kPrimHoldLen); inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (int64_t i = 0; i < SizeToLong(size); ++i) { + for (size_t i = 0; i < size; i++) { + MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_; std::vector inputs2; inputs2.push_back(fn_rec); if (fn_arg != nullptr) { inputs2.push_back(fn_arg); } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { - return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); + size_t pos = (reverse_ ? 
(size - 1 - i) : i); + (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), + [&func_graph, &pos](std::pair item) { + return func_graph->NewCNodeInOrder( + {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))}); + }); auto call_node = func_graph->NewCNodeInOrder(inputs2); - inputs.push_back(call_node); + if (reverse_) { + inputs.insert(inputs.begin() + 1, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -235,29 +277,38 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(func_graph); + std::size_t attrSize = type->GetAttributes().size(); + constexpr size_t kPrimAndTypeLen = 2; std::vector inputs; + inputs.reserve(attrSize + kPrimAndTypeLen); inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); inputs.push_back(NewValueNode(type)); // cannot use shared_from_base() also known as this, as it will make a reference cycle on // hypermap and graph generated, it will cause memory leak. auto fn_rec = NewValueNode(std::make_shared(*this)); - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { + for (std::size_t i = 0; i < attrSize; i++) { + MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_; std::vector inputs2; inputs2.push_back(fn_rec); if (fn_arg) { inputs2.push_back(fn_arg); } - int64_t j = 0; - for (auto item : arg_map) { - inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; + size_t size = arg_map.size(); + for (size_t j = 0; j < size; j++) { + size_t pos = (reverse_ ? (size - 1 - j) : j); + auto &item = arg_map[pos]; + inputs2.push_back( + func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))})); } auto call_node = func_graph->NewCNodeInOrder(inputs2); - inputs.push_back(call_node); + if (reverse_) { + inputs.insert(inputs.begin() + 2, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -383,8 +434,9 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { (void)py::class_>(*m, "HyperMap_") - .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); + .def(py::init>(), py::arg("reverse"), + py::arg("ops")) + .def(py::init(), py::arg("reverse")); })); FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const { diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h index 061534bf8a1..28476315874 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.h +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -48,12 +48,13 @@ using ArgsPairList = std::vector>; class HyperMap : public MetaFuncGraph { public: - explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + explicit HyperMap(bool reverse = false, const std::shared_ptr &fn_leaf = nullptr); HyperMap(const HyperMap &h); void Init(); HyperMap &operator=(const HyperMap &h) { if (this != &h) { fn_leaf_ = h.fn_leaf_; + reverse_ = h.reverse_; broadcast_ = h.broadcast_; nonleaf_ = h.nonleaf_; if (fn_leaf_) { @@ -82,6 +83,7 @@ class HyperMap : public MetaFuncGraph { ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); 
MultitypeFuncGraphPtr fn_leaf_; + bool reverse_; bool broadcast_; std::set nonleaf_; }; @@ -89,7 +91,8 @@ using HyperMapPtr = std::shared_ptr; class HyperMapPy : public HyperMap { public: - explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} + explicit HyperMapPy(bool reverse = false, const std::shared_ptr &fn_leaf = nullptr) + : HyperMap(reverse, fn_leaf) {} ~HyperMapPy() override = default; MS_DECLARE_PARENT(HyperMapPy, HyperMap) }; diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc index 7595a13708d..ca0766996a9 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.cc +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -71,21 +71,33 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphP MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); + size_t num = 0; bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair &item) { + num++; auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; + if (lhs == nullptr) { + MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got " + << item.second->ToString(); + } + if (lhs->elements().size() != size) { + MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got " + << lhs->elements().size(); + return true; + } + return false; }); if (is_not_same) { MS_LOG(EXCEPTION) << "List in Map should have same length"; } + constexpr size_t kPrimHoldLen = 1; std::vector inputs; + inputs.reserve(size + kPrimHoldLen); inputs.push_back(NewValueNode(prim::kPrimMakeList)); - for (int64_t i = 0; i < SizeToLong(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; + for (size_t i = 0; i < size; i++) { + MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_; auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); auto fn = NewValueNode(ptrGraph); @@ -95,13 +107,19 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphP inputs2.push_back(fn_arg); } - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); + size_t pos = (reverse_ ? 
(size - 1 - i) : i); + (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, pos](const std::pair &item) { + return func_graph->NewCNodeInOrder( + {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))}); + }); - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + if (reverse_) { + inputs.insert(inputs.begin() + 1, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -111,22 +129,34 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGrap MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); - std::size_t size = type->elements().size(); + size_t size = type->elements().size(); + size_t num = 0; bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair &item) { + num++; auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; + if (lhs == nullptr) { + MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got " + << item.second->ToString(); + } + if (lhs->elements().size() != size) { + MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got " + << lhs->elements().size(); + return true; + } + return false; }); if (is_not_same) { MS_LOG(EXCEPTION) << "tuple in Map should have same length"; } + constexpr size_t kPrimHoldLen = 1; std::vector inputs; + inputs.reserve(size + kPrimHoldLen); inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (int64_t i = 0; i < SizeToLong(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; + for (size_t i = 0; i < size; i++) { + MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_; auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); auto fn = NewValueNode(ptrGraph); @@ -136,13 +166,19 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGrap inputs2.push_back(fn_arg); } - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, &i](std::pair item) { - return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); + size_t pos = (reverse_ ? 
(size - 1 - i) : i); + (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, &pos](const std::pair &item) { + return func_graph->NewCNodeInOrder( + {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))}); + }); - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + if (reverse_) { + inputs.insert(inputs.begin() + 1, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -152,13 +188,15 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGrap MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(func_graph); + size_t attrSize = type->GetAttributes().size(); + constexpr size_t kPrimAndTypeLen = 2; std::vector inputs; + inputs.reserve(attrSize + kPrimAndTypeLen); inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); inputs.push_back(NewValueNode(type)); - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; + for (size_t i = 0; i < attrSize; i++) { + MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_; auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); auto fn = NewValueNode(ptrGraph); @@ -168,13 +206,20 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGrap inputs2.push_back(fn_arg); } - int64_t j = 0; - for (auto item : arg_pairs) { - inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; + size_t size = arg_pairs.size(); + for (size_t j = 0; j < size; j++) { + size_t pos = (reverse_ ? 
(size - 1 - j) : j); + auto &item = arg_pairs[pos]; + inputs2.push_back( + func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))})); } - inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); + auto call_node = func_graph->NewCNodeInOrder(inputs2); + if (reverse_) { + inputs.insert(inputs.begin() + 2, call_node); + } else { + inputs.emplace_back(call_node); + } } return func_graph->NewCNodeInOrder(inputs); } @@ -284,8 +329,9 @@ abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { (void)py::class_>(*m, "Map_") - .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); + .def(py::init>(), py::arg("reverse"), + py::arg("ops")) + .def(py::init(), py::arg("reverse")); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/map.h b/mindspore/ccsrc/frontend/operator/composite/map.h index da8edcef43d..026a936b56c 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.h +++ b/mindspore/ccsrc/frontend/operator/composite/map.h @@ -33,21 +33,28 @@ using ArgsPairList = std::vector>; class Map : public MetaFuncGraph { public: - explicit Map(const std::shared_ptr &fn_leaf = nullptr) + explicit Map(bool reverse = false, const std::shared_ptr &fn_leaf = nullptr) : MetaFuncGraph("map"), fn_leaf_(fn_leaf), + reverse_(reverse), broadcast_(false), nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { Init(); } - Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Map(const Map &map) + : MetaFuncGraph("map"), + fn_leaf_(map.fn_leaf_), + reverse_(map.reverse_), + broadcast_(map.broadcast_), + nonleaf_(map.nonleaf_) { Init(); } - Map &operator=(const Map &h) { - if (this != &h) { - fn_leaf_ = h.fn_leaf_; - broadcast_ = h.broadcast_; - nonleaf_ = h.nonleaf_; + Map &operator=(const Map &map) { + if (this != &map) { + fn_leaf_ = map.fn_leaf_; + reverse_ = map.reverse_; + broadcast_ = map.broadcast_; + nonleaf_ = map.nonleaf_; if (fn_leaf_) { name_ = "map[" + fn_leaf_->name() + "]"; } @@ -81,13 +88,15 @@ class Map : public MetaFuncGraph { } MultitypeFuncGraphPtr fn_leaf_; + bool reverse_; bool broadcast_; std::set nonleaf_; }; using MapPtr = std::shared_ptr; class MapPy : public Map { public: - explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} + explicit MapPy(bool reverse = false, const std::shared_ptr &fn_leaf = nullptr) + : Map(reverse, fn_leaf) {} ~MapPy() override = default; MS_DECLARE_PARENT(MapPy, Map) }; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 9c793970242..7a72d5b7de8 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -127,22 +127,13 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr << ", res_base: " << res_base->ToString(); break; } - - // Overwrite the result if func graph is stub. - if (current_stack_frame->func_graph()->stub()) { - eval_result = std::make_shared(std::make_shared(), nullptr); - } - // Save func graph eval result for specialize. - auto evaluator = current_stack_frame->evaluator(); - MS_EXCEPTION_IF_NULL(evaluator); - evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result); - // Leave current func graph. 
LeaveStackFrame(engine, current_stack_frame); // Switch the stack frame. + auto last_stack_frame = current_stack_frame; current_stack_frame = stack_frames.top(); MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame; - current_stack_frame->Back(engine, eval_result); + current_stack_frame->Back(engine, last_stack_frame, eval_result); continue; } @@ -223,7 +214,6 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":" << fg->ToString() << "();"; } - context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list); const auto ¶meters = fg->parameters(); for (size_t i = 0; i < nargs; i++) { @@ -399,8 +389,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) { - const string evaluator_name = ToString(); - std::unique_lock eval_lock(eval_loc_, std::try_to_lock); + std::unique_lock eval_lock(eval_lock_, std::try_to_lock); if (!eval_lock.owns_lock()) { auto py_tstate = PyEval_SaveThread(); eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout)); @@ -420,26 +409,26 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args args_spec_list = BroadenUndeterminedArgs(args_spec_list); trace::TraceGraphEvalEnter(shared_from_base(), out_conf); MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); + const std::string &evaluator_name = ToString(); MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_); auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list); if (eval_result == nullptr) { MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; - EvalResultPtr res = Eval(engine, args_spec_list); - if (res->abstract() == nullptr) { + eval_result = Eval(engine, args_spec_list); + MS_EXCEPTION_IF_NULL(eval_result); + if (eval_result->abstract() == nullptr) { EvalFailLogging(shared_from_base(), args_spec_list, out_conf); MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; } - MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << res->abstract()->ToString() << "."; - evaluator_cache_mgr_->SetValue(args_spec_list, res); - trace::TraceGraphEvalLeave(shared_from_base()); - return res; + MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << eval_result->abstract()->ToString() << "."; + evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); } else { MS_EXCEPTION_IF_NULL(eval_result); MS_EXCEPTION_IF_NULL(eval_result->abstract()); MS_LOG(DEBUG) << evaluator_name << " cache hit. 
return: " << eval_result->abstract()->ToString() << "."; - trace::TraceGraphEvalLeave(shared_from_base()); - return eval_result; } + trace::TraceGraphEvalLeave(shared_from_base()); + return eval_result; } EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h index 78b85192a69..9e4037cedf3 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ -40,9 +40,9 @@ using EvaluatorAttrCachePtr = std::shared_ptr; class Evaluator : public Base { public: explicit Evaluator(const std::string &id) - : evaluator_cache_mgr_(std::make_shared()), - attr_cache_(std::make_shared()), - identifier_(id) {} + : identifier_(id), + evaluator_cache_mgr_(std::make_shared()), + attr_cache_(std::make_shared()) {} ~Evaluator() override = default; MS_DECLARE_PARENT(Evaluator, Base); @@ -92,13 +92,16 @@ class Evaluator : public Base { EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; } EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; } - EvaluatorCacheMgrPtr evaluator_cache_mgr_; - EvaluatorAttrCachePtr attr_cache_; + std::recursive_timed_mutex &eval_lock() { return eval_lock_; } + protected: std::string identifier_; - AnfNodeWeakPtr bound_node_; - std::recursive_timed_mutex eval_loc_; + EvaluatorCacheMgrPtr evaluator_cache_mgr_; + std::recursive_timed_mutex eval_lock_; + + private: + EvaluatorAttrCachePtr attr_cache_; }; class PrimEvaluator : public Evaluator { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc index 5cb9631f73c..829322bbee5 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc @@ -65,7 +65,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr // Evaluate the inputs firstly. Build arguments for the func graph. AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode); - // Check if already evaluated before. if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) { return nullptr; @@ -80,7 +79,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr << fg->parameters().size() << ", but the number of provided arguments is " << args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); } - MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString() << ", current_context_: " << current_context_->ToString(); @@ -133,8 +131,20 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) { return node_eval_result; } -// Return back from branch func graph. -void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) { +// Return back from child func graph. +void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, + const EvalResultPtr &eval_result) { + // Overwrite the result if func graph is stub. + EvalResultPtr result = eval_result; + if (last_stack_frame->func_graph()->stub()) { + result = std::make_shared(std::make_shared(), nullptr); + } + // Save func graph eval result for specialize. 
+ auto evaluator = last_stack_frame->evaluator(); + MS_EXCEPTION_IF_NULL(evaluator); + evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result); + + // Continue saving node's result for parent func graph. auto ¤t_node = NextNode(); MS_LOG(DEBUG) << "current_node: " << current_node->DebugString() << ", current_context_: " << current_context_->ToString(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.h b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.h index a332b9d034d..77fb3362149 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.h @@ -46,7 +46,7 @@ class StackFrame : public Base { virtual ~StackFrame() = default; void Load() { - node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType { + node_slots_ = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType { if (node->isa() || node->isa()) { return EXCLUDE; } @@ -61,17 +61,17 @@ class StackFrame : public Base { // Run one step in current func graph. EvalResultPtr Step(const AnalysisEnginePtr &engine); // Return back from branch func graph. - void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result); + void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result); bool Done() { return done_; } AnfNodePtr &CurrentNode() { - if (slot_index_ >= node_slots.size()) { + if (slot_index_ >= node_slots_.size()) { MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToString() << " is invalid. Try to access frame sequence by index " << slot_index_ - << ", while the size is " << node_slots.size() << "."; + << ", while the size is " << node_slots_.size() << "."; } - return node_slots[slot_index_]; + return node_slots_[slot_index_]; } AnfNodePtr &NextNode() { @@ -97,8 +97,8 @@ class StackFrame : public Base { MS_EXCEPTION_IF_NULL(func_graph_); std::ostringstream buffer; buffer << "StackFrame: " << this << ", " << func_graph_->ToString(); - if (slot_index_ < node_slots.size()) { - auto current_node = node_slots[slot_index_]; + if (slot_index_ < node_slots_.size()) { + auto current_node = node_slots_[slot_index_]; buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")"; } else { buffer << "(Exhausted..)"; @@ -132,7 +132,7 @@ class StackFrame : public Base { AnalysisContextPtr current_context_; AnalysisContextPtr parent_context_; AbstractBasePtrList args_abs_list_; - std::vector node_slots; + std::vector node_slots_; size_t slot_index_; bool done_; }; diff --git a/mindspore/nn/optim/ada_grad.py b/mindspore/nn/optim/ada_grad.py index a12f470bd3f..d68be663549 100644 --- a/mindspore/nn/optim/ada_grad.py +++ b/mindspore/nn/optim/ada_grad.py @@ -149,7 +149,6 @@ class Adagrad(Optimizer): super(Adagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) _check_param_value(accum, update_slots, self.cls_name) self.accum = self.parameters.clone(prefix="accum", init=accum) - self.hyper_map = C.HyperMap() self.update_slots = update_slots self.opt = P.ApplyAdagrad(update_slots=update_slots) @@ -161,9 +160,9 @@ class Adagrad(Optimizer): grads = self.scale_grad(grads) lr = self.get_lr() if self.is_group_lr: - success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum, - grads) + success = self.map_reverse(F.partial(_ada_grad_opt, self.opt), lr, params, accum, + grads) else: - success = 
self.map_(F.partial(_ada_grad_opt, self.opt, lr), params, accum, - grads) + success = self.map_reverse(F.partial(_ada_grad_opt, self.opt, lr), params, accum, + grads) return success diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index b45599fef7a..1699e01a73e 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -329,7 +329,6 @@ class Adam(Optimizer): self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') self._is_device = True - self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov) self.sparse_opt.add_prim_attr("primitive_target", "CPU") @@ -352,15 +351,15 @@ class Adam(Optimizer): beta2_power = self.beta2_power * self.beta2 self.beta2_power = beta2_power if self.is_group_lr: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, - self.use_locking, self.use_nesterov, self._is_device, - beta1_power, beta2_power, self.beta1, self.beta2, self.eps), - lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) + success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + self.use_locking, self.use_nesterov, self._is_device, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps), + lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) else: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, - self.use_locking, self.use_nesterov, self._is_device, - beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), - gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) + success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + self.use_locking, self.use_nesterov, self._is_device, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), + gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) return success @Optimizer.target.setter @@ -460,23 +459,23 @@ class AdamWeightDecay(Optimizer): self.eps = Tensor(np.array([eps]).astype(np.float32)) self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') - self.hyper_map = C.HyperMap() def construct(self, gradients): lr = self.get_lr() if self.is_group: if self.is_group_lr: - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), - lr, self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, + self.moments2, gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), - self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), - self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result 
= self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) if self.use_parallel: self.broadcast_params(optim_result) return optim_result @@ -614,8 +613,6 @@ class AdamOffload(Optimizer): self.use_locking = use_locking self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') - - self.hyper_map = C.HyperMap() self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov) self.opt.add_prim_attr("primitive_target", "CPU") @@ -632,11 +629,11 @@ class AdamOffload(Optimizer): beta2_power = self.beta2_power * self.beta2 self.beta2_power = beta2_power if self.is_group_lr: - success = self.map_(F.partial(_adam_opt, self.opt, - beta1_power, beta2_power, self.beta1, self.beta2, self.eps), - lr, gradients, params, moment1, moment2) + success = self.map_reverse(F.partial(_adam_opt, self.opt, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps), + lr, gradients, params, moment1, moment2) else: - success = self.map_(F.partial(_adam_opt, self.opt, - beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), - gradients, params, moment1, moment2) + success = self.map_reverse(F.partial(_adam_opt, self.opt, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), + gradients, params, moment1, moment2) return success diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index 0faa10628e0..27b5c9180a3 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -205,7 +205,6 @@ class FTRL(Optimizer): self.lr_power = lr_power if not self.is_group: self.decay_flags = tuple((lambda: True)() for x in self.parameters) - self.hyper_map = C.HyperMap() self.opt = P.ApplyFtrl(use_locking=use_locking) self.use_locking = use_locking self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) @@ -227,9 +226,9 @@ class FTRL(Optimizer): grads = self._grad_sparse_indices_deduplicate(grads) lr = self.get_lr() - success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, - self.l1, self.l2, self.lr_power, lr), - linear, grads, params, moments, self.ps_parameters, self.cache_enable) + success = self.map_reverse(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + self.l1, self.l2, self.lr_power, lr), + linear, grads, params, moments, self.ps_parameters, self.cache_enable) return success @Optimizer.target.setter diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 02d43036e16..7277cdb91ba 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -281,7 +281,6 @@ class Lamb(Optimizer): if not self.dynamic_lr: self.global_step = Parameter(initializer(0, [1]), name='global_step') self.assignadd = P.AssignAdd() - self.hyper_map = C.HyperMap() self.device_ascend = context.get_context("device_target") == "Ascend" def construct(self, gradients): @@ -290,20 +289,20 @@ class Lamb(Optimizer): gradients = self.gradients_centralization(gradients) if self.is_group: if self.is_group_lr: - optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - self.global_step), - lr, self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step), + lr, 
self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - self.global_step, lr), - self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - self.global_step, lr, self.weight_decay), - self.params, self.moments1, self.moments2, gradients, - self.decay_flags, self.optim_filter) + optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, + self.decay_flags, self.optim_filter) if self.use_parallel: optim_result = F.depend(optim_result, self.broadcast_params(optim_result)) diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 2b67303e77b..2f64a0650fd 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -115,7 +115,6 @@ class LARS(Optimizer): self.decay_flags = optimizer.decay_flags self.reciprocal_scale = optimizer.reciprocal_scale self.need_scale = optimizer.need_scale - self.hyper_map = C.HyperMap() self.lars = P.LARSUpdate(epsilon, coefficient, use_clip) self.cast = P.Cast() diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index c45b31e2a8a..45db2463d5f 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -248,8 +248,6 @@ class LazyAdam(Optimizer): self._is_device = True self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') - - self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov) self.sparse_opt.add_prim_attr("primitive_target", "CPU") @@ -268,17 +266,18 @@ class LazyAdam(Optimizer): self.beta2_power = self.beta2_power * self.beta2 if self.is_group_lr: - success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, - self.use_locking, self.use_nesterov, self._is_device, - self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps), - lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, - self.cache_enable) + success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, + self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, + self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps), + lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, + self.cache_enable) else: - success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, - self.use_locking, self.use_nesterov, self._is_device, - self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr), - gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, - self.cache_enable) + success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, + self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, + self.beta1_power, self.beta2_power, self.beta1, 
self.beta2, self.eps, + lr), + gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, + self.cache_enable) return success @Optimizer.target.setter diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 726b4f32bab..aacd32b9bec 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -157,7 +157,6 @@ class Momentum(Optimizer): self.params = self.parameters self.use_nesterov = Validator.check_bool(use_nesterov) self.moments = self.params.clone(prefix="moments", init='zeros') - self.hyper_map = C.HyperMap() self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov) def construct(self, gradients): @@ -168,9 +167,9 @@ class Momentum(Optimizer): gradients = self.scale_grad(gradients) lr = self.get_lr() if self.is_group_lr: - success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments, - self.ps_parameters, self.cache_enable) + success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum), + lr, gradients, params, moments, self.ps_parameters, self.cache_enable) else: - success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments, - self.ps_parameters, self.cache_enable) + success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum, lr), + gradients, params, moments, self.ps_parameters, self.cache_enable) return success diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 37577cf841c..fb9385af8a4 100644 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -197,6 +197,9 @@ class Optimizer(Cell): self.global_step_increase_tensor = Tensor(1, mstype.int32) self.param_length = len(self.parameters) self.map_ = C.Map() + self.map_reverse = C.Map(None, True) + self.hyper_map = C.HyperMap() + self.hyper_map_reverse = C.HyperMap(None, True) self._use_parallel_optimizer() def _use_parallel_optimizer(self): diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index 7202bdc2eb5..c4fe9c21d3e 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -160,7 +160,6 @@ class ProximalAdagrad(Optimizer): self.accum = self.parameters.clone(prefix="accum", init=accum) self.l1 = Tensor(l1, mstype.float32) self.l2 = Tensor(l2, mstype.float32) - self.hyper_map = C.HyperMap() self.use_locking = use_locking self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking) @@ -174,11 +173,12 @@ class ProximalAdagrad(Optimizer): grads = self._grad_sparse_indices_deduplicate(grads) lr = self.get_lr() if self.is_group_lr: - success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr, - grads, params, accum) + success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), + lr, grads, params, accum) else: - success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr), - grads, params, accum) + success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, + lr), + grads, params, accum) return success @Optimizer.target.setter diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index e0f141c45d3..0f03a811165 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -197,7 +197,6 @@ class 
RMSProp(Optimizer): self.momentum = momentum self.ms = self.parameters.clone(prefix="mean_square", init='ones') self.moment = self.parameters.clone(prefix="moment", init='zeros') - self.hyper_map = C.HyperMap() self.epsilon = epsilon self.decay = decay @@ -209,17 +208,20 @@ class RMSProp(Optimizer): lr = self.get_lr() if self.centered: if self.is_group_lr: - success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon, - self.momentum), lr, params, self.mg, self.ms, self.moment, gradients) + success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum), + lr, params, self.mg, self.ms, self.moment, gradients) else: - success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon, - self.momentum, lr), params, self.mg, self.ms, self.moment, gradients) - + success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum, lr), + params, self.mg, self.ms, self.moment, gradients) else: if self.is_group_lr: - success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon, - self.momentum), lr, params, self.ms, self.moment, gradients) + success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum), + lr, params, self.ms, self.moment, gradients) else: - success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon, - self.momentum, lr), params, self.ms, self.moment, gradients) + success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum, lr), + params, self.ms, self.moment, gradients) return success diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 578e999e8a5..46beb91920d 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -170,7 +170,6 @@ class SGD(Optimizer): self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.accum = self.parameters.clone(prefix="accum", init='zeros') self.stat = self.parameters.clone(prefix="stat", init='ones') - self.hyper_map = C.HyperMap() def construct(self, gradients): params = self.parameters @@ -180,7 +179,9 @@ class SGD(Optimizer): gradients = self.scale_grad(gradients) lr = self.get_lr() if self.is_group_lr: - success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat) + success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum), + lr, gradients, params, accum, stat) else: - success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat) + success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum, lr), + gradients, params, accum, stat) return success diff --git a/mindspore/nn/optim/thor.py b/mindspore/nn/optim/thor.py index 67091f618c3..6a90106becf 100644 --- a/mindspore/nn/optim/thor.py +++ b/mindspore/nn/optim/thor.py @@ -318,7 +318,6 @@ class ThorGpu(Optimizer): self.params = self.parameters self.use_nesterov = Validator.check_bool(use_nesterov) self.moments = self.params.clone(prefix="moments", init='zeros') - self.hyper_map = C.HyperMap() self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov) self.net = net self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters())) @@ -606,7 +605,6 @@ class ThorAscend(Optimizer): self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = 
self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
-        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.net = net
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index f494d3fdf15..ba53f7d0b5d 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -483,6 +483,7 @@ class HyperMap(HyperMap_):
    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance.
+        reverse (bool): If True, apply the operations to the sequence elements in reverse order. Only supported in graph mode. Default: False.
    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same length.
@@ -517,20 +518,20 @@
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
-        >>> square_map = HyperMap(square)
+        >>> square_map = HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """
-    def __init__(self, ops=None):
+    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        if ops:
-            HyperMap_.__init__(self, ops)
+            HyperMap_.__init__(self, reverse, ops)
        else:
-            HyperMap_.__init__(self)
+            HyperMap_.__init__(self, reverse)
    def __call__(self, *args):
        func = self.ops
@@ -555,6 +556,7 @@ class Map(Map_):
    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance. Default: None
+        reverse (bool): If True, apply the operations to the sequence elements in reverse order. Only supported in graph mode. Default: False.
    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
@@ -581,20 +583,20 @@
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
-        >>> square_map = Map(square)
+        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """
-    def __init__(self, ops=None):
+    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        if ops:
-            Map_.__init__(self, ops)
+            Map_.__init__(self, reverse, ops)
        else:
-            Map_.__init__(self)
+            Map_.__init__(self, reverse)
    def __call__(self, *args):
        func = self.ops