!18751 Add reverse option for HyperMap and Map, and enable reverse option for optimizers.

Merge pull request !18751 from 张清华/opt
This commit is contained in:
i-robot 2021-06-24 12:11:53 +00:00 committed by Gitee
commit 6bc1998a94
21 changed files with 326 additions and 217 deletions

View File

@ -123,16 +123,21 @@ void HyperMap::Init() {
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
: MetaFuncGraph("hyper_map"),
fn_leaf_(fn_leaf),
reverse_(reverse),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
Init();
}
HyperMap::HyperMap(const HyperMap &h)
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
: MetaFuncGraph("hyper_map"),
fn_leaf_(h.fn_leaf_),
reverse_(h.reverse_),
broadcast_(h.broadcast_),
nonleaf_(h.nonleaf_) {
Init();
}
@ -156,12 +161,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<List>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
});
size_t size = type->elements().size();
size_t num = 0;
bool is_not_same =
std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
num++;
auto lhs = std::static_pointer_cast<List>(item.second);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got "
<< item.second->ToString();
}
if (lhs->elements().size() != size) {
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
<< lhs->elements().size();
return true;
}
return false;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
}
@ -169,24 +185,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
// Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
// HyperMap and the generated graph, which would cause a memory leak.
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
constexpr size_t kPrimHoldLen = 1;
std::vector<AnfNodePtr> inputs;
inputs.reserve(size + kPrimHoldLen);
inputs.push_back(NewValueNode(prim::kPrimMakeList));
for (int64_t i = 0; i < SizeToLong(size); ++i) {
for (size_t i = 0; i < size; i++) {
MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn_rec);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
});
size_t pos = (reverse_ ? (size - 1 - i) : i);
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
[&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
});
auto call_node = func_graph->NewCNodeInOrder(inputs2);
inputs.push_back(call_node);
if (reverse_) {
inputs.insert(inputs.begin() + 1, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -196,12 +219,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<Tuple>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
});
size_t size = type->elements().size();
size_t num = 0;
bool is_not_same =
std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
num++;
auto lhs = std::static_pointer_cast<Tuple>(item.second);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got "
<< item.second->ToString();
}
if (lhs->elements().size() != size) {
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
<< lhs->elements().size();
return true;
}
return false;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length";
}
@ -209,23 +243,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
// Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
// HyperMap and the generated graph, which would cause a memory leak.
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
constexpr size_t kPrimHoldLen = 1;
std::vector<AnfNodePtr> inputs;
inputs.reserve(size + kPrimHoldLen);
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (int64_t i = 0; i < SizeToLong(size); ++i) {
for (size_t i = 0; i < size; i++) {
MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn_rec);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
});
size_t pos = (reverse_ ? (size - 1 - i) : i);
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
[&func_graph, &pos](std::pair<AnfNodePtr, Any> item) {
return func_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
});
auto call_node = func_graph->NewCNodeInOrder(inputs2);
inputs.push_back(call_node);
if (reverse_) {
inputs.insert(inputs.begin() + 1, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -235,29 +277,38 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGrap
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
std::size_t attrSize = type->GetAttributes().size();
constexpr size_t kPrimAndTypeLen = 2;
std::vector<AnfNodePtr> inputs;
inputs.reserve(attrSize + kPrimAndTypeLen);
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
// Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
// HyperMap and the generated graph, which would cause a memory leak.
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
std::size_t attrSize = type->GetAttributes().size();
for (std::size_t i = 0; i < attrSize; ++i) {
for (std::size_t i = 0; i < attrSize; i++) {
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn_rec);
if (fn_arg) {
inputs2.push_back(fn_arg);
}
int64_t j = 0;
for (auto item : arg_map) {
inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
j++;
size_t size = arg_map.size();
for (size_t j = 0; j < size; j++) {
size_t pos = (reverse_ ? (size - 1 - j) : j);
auto &item = arg_map[pos];
inputs2.push_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
}
auto call_node = func_graph->NewCNodeInOrder(inputs2);
inputs.push_back(call_node);
if (reverse_) {
inputs.insert(inputs.begin() + 2, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -383,8 +434,9 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
.def(py::init<>());
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
py::arg("ops"))
.def(py::init<bool>(), py::arg("reverse"));
}));
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {

View File

@ -48,12 +48,13 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class HyperMap : public MetaFuncGraph {
public:
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
HyperMap(const HyperMap &h);
void Init();
HyperMap &operator=(const HyperMap &h) {
if (this != &h) {
fn_leaf_ = h.fn_leaf_;
reverse_ = h.reverse_;
broadcast_ = h.broadcast_;
nonleaf_ = h.nonleaf_;
if (fn_leaf_) {
@ -82,6 +83,7 @@ class HyperMap : public MetaFuncGraph {
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
MultitypeFuncGraphPtr fn_leaf_;
bool reverse_;
bool broadcast_;
std::set<TypeId> nonleaf_;
};
@ -89,7 +91,8 @@ using HyperMapPtr = std::shared_ptr<HyperMap>;
class HyperMapPy : public HyperMap {
public:
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
explicit HyperMapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
: HyperMap(reverse, fn_leaf) {}
~HyperMapPy() override = default;
MS_DECLARE_PARENT(HyperMapPy, HyperMap)
};

View File

@ -71,21 +71,33 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
size_t num = 0;
bool is_not_same =
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
num++;
auto lhs = std::dynamic_pointer_cast<List>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got "
<< item.second->ToString();
}
if (lhs->elements().size() != size) {
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
<< lhs->elements().size();
return true;
}
return false;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "List in Map should have same length";
}
constexpr size_t kPrimHoldLen = 1;
std::vector<AnfNodePtr> inputs;
inputs.reserve(size + kPrimHoldLen);
inputs.push_back(NewValueNode(prim::kPrimMakeList));
for (int64_t i = 0; i < SizeToLong(size); ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
for (size_t i = 0; i < size; i++) {
MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_;
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
@ -95,13 +107,19 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
});
size_t pos = (reverse_ ? (size - 1 - i) : i);
(void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
});
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
inputs.insert(inputs.begin() + 1, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -111,22 +129,34 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
size_t size = type->elements().size();
size_t num = 0;
bool is_not_same =
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
num++;
auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got "
<< item.second->ToString();
}
if (lhs->elements().size() != size) {
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
<< lhs->elements().size();
return true;
}
return false;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "tuple in Map should have same length";
}
constexpr size_t kPrimHoldLen = 1;
std::vector<AnfNodePtr> inputs;
inputs.reserve(size + kPrimHoldLen);
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (int64_t i = 0; i < SizeToLong(size); ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
for (size_t i = 0; i < size; i++) {
MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_;
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
@ -136,13 +166,19 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
});
size_t pos = (reverse_ ? (size - 1 - i) : i);
(void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, &pos](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
});
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
inputs.insert(inputs.begin() + 1, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -152,13 +188,15 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
size_t attrSize = type->GetAttributes().size();
constexpr size_t kPrimAndTypeLen = 2;
std::vector<AnfNodePtr> inputs;
inputs.reserve(attrSize + kPrimAndTypeLen);
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
std::size_t attrSize = type->GetAttributes().size();
for (std::size_t i = 0; i < attrSize; ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
for (size_t i = 0; i < attrSize; i++) {
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_;
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
@ -168,13 +206,20 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
inputs2.push_back(fn_arg);
}
int64_t j = 0;
for (auto item : arg_pairs) {
inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
j++;
size_t size = arg_pairs.size();
for (size_t j = 0; j < size; j++) {
size_t pos = (reverse_ ? (size - 1 - j) : j);
auto &item = arg_pairs[pos];
inputs2.push_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
}
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
inputs.insert(inputs.begin() + 2, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -284,8 +329,9 @@ abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args
REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
.def(py::init<>());
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
py::arg("ops"))
.def(py::init<bool>(), py::arg("reverse"));
}));
} // namespace prim
} // namespace mindspore

View File

@ -33,21 +33,28 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class Map : public MetaFuncGraph {
public:
explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
: MetaFuncGraph("map"),
fn_leaf_(fn_leaf),
reverse_(reverse),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
Init();
}
Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
Map(const Map &map)
: MetaFuncGraph("map"),
fn_leaf_(map.fn_leaf_),
reverse_(map.reverse_),
broadcast_(map.broadcast_),
nonleaf_(map.nonleaf_) {
Init();
}
Map &operator=(const Map &h) {
if (this != &h) {
fn_leaf_ = h.fn_leaf_;
broadcast_ = h.broadcast_;
nonleaf_ = h.nonleaf_;
Map &operator=(const Map &map) {
if (this != &map) {
fn_leaf_ = map.fn_leaf_;
reverse_ = map.reverse_;
broadcast_ = map.broadcast_;
nonleaf_ = map.nonleaf_;
if (fn_leaf_) {
name_ = "map[" + fn_leaf_->name() + "]";
}
@ -81,13 +88,15 @@ class Map : public MetaFuncGraph {
}
MultitypeFuncGraphPtr fn_leaf_;
bool reverse_;
bool broadcast_;
std::set<TypeId> nonleaf_;
};
using MapPtr = std::shared_ptr<Map>;
class MapPy : public Map {
public:
explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {}
explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
: Map(reverse, fn_leaf) {}
~MapPy() override = default;
MS_DECLARE_PARENT(MapPy, Map)
};

View File

@ -127,22 +127,13 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
<< ", res_base: " << res_base->ToString();
break;
}
// Overwrite the result if func graph is stub.
if (current_stack_frame->func_graph()->stub()) {
eval_result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
}
// Save func graph eval result for specialize.
auto evaluator = current_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result);
// Leave current func graph.
LeaveStackFrame(engine, current_stack_frame);
// Switch the stack frame.
auto last_stack_frame = current_stack_frame;
current_stack_frame = stack_frames.top();
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
current_stack_frame->Back(engine, eval_result);
current_stack_frame->Back(engine, last_stack_frame, eval_result);
continue;
}
@ -223,7 +214,6 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
<< fg->ToString() << "();";
}
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
const auto &parameters = fg->parameters();
for (size_t i = 0; i < nargs; i++) {
@ -399,8 +389,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
const string evaluator_name = ToString();
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_loc_, std::try_to_lock);
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_lock_, std::try_to_lock);
if (!eval_lock.owns_lock()) {
auto py_tstate = PyEval_SaveThread();
eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout));
@ -420,26 +409,26 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
const std::string &evaluator_name = ToString();
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result == nullptr) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
EvalResultPtr res = Eval(engine, args_spec_list);
if (res->abstract() == nullptr) {
eval_result = Eval(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(eval_result);
if (eval_result->abstract() == nullptr) {
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
}
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << res->abstract()->ToString() << ".";
evaluator_cache_mgr_->SetValue(args_spec_list, res);
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return res;
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << eval_result->abstract()->ToString() << ".";
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
} else {
MS_EXCEPTION_IF_NULL(eval_result);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return eval_result;
}
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return eval_result;
}
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,

View File

@ -40,9 +40,9 @@ using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
class Evaluator : public Base {
public:
explicit Evaluator(const std::string &id)
: evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
attr_cache_(std::make_shared<EvaluatorAttrCache>()),
identifier_(id) {}
: identifier_(id),
evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
attr_cache_(std::make_shared<EvaluatorAttrCache>()) {}
~Evaluator() override = default;
MS_DECLARE_PARENT(Evaluator, Base);
@ -92,13 +92,16 @@ class Evaluator : public Base {
EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
EvaluatorAttrCachePtr attr_cache_;
std::recursive_timed_mutex &eval_lock() { return eval_lock_; }
protected:
std::string identifier_;
AnfNodeWeakPtr bound_node_;
std::recursive_timed_mutex eval_loc_;
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
std::recursive_timed_mutex eval_lock_;
private:
EvaluatorAttrCachePtr attr_cache_;
};
class PrimEvaluator : public Evaluator {

View File

@ -65,7 +65,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
// Evaluate the inputs firstly. Build arguments for the func graph.
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
// Check if already evaluated before.
if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) {
return nullptr;
@ -80,7 +79,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
<< fg->parameters().size() << ", but the number of provided arguments is "
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
}
MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString()
<< ", current_context_: " << current_context_->ToString();
@ -133,8 +131,20 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
return node_eval_result;
}
// Return back from branch func graph.
void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) {
// Return back from child func graph.
void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame,
const EvalResultPtr &eval_result) {
// Overwrite the result if func graph is stub.
EvalResultPtr result = eval_result;
if (last_stack_frame->func_graph()->stub()) {
result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
}
// Save func graph eval result for specialize.
auto evaluator = last_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result);
// Continue saving node's result for parent func graph.
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();

View File

@ -46,7 +46,7 @@ class StackFrame : public Base {
virtual ~StackFrame() = default;
void Load() {
node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
node_slots_ = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return EXCLUDE;
}
@ -61,17 +61,17 @@ class StackFrame : public Base {
// Run one step in current func graph.
EvalResultPtr Step(const AnalysisEnginePtr &engine);
// Return back from child func graph.
void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result);
void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result);
bool Done() { return done_; }
AnfNodePtr &CurrentNode() {
if (slot_index_ >= node_slots.size()) {
if (slot_index_ >= node_slots_.size()) {
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToString()
<< " is invalid. Try to access frame sequence by index " << slot_index_
<< ", while the size is " << node_slots.size() << ".";
<< ", while the size is " << node_slots_.size() << ".";
}
return node_slots[slot_index_];
return node_slots_[slot_index_];
}
AnfNodePtr &NextNode() {
@ -97,8 +97,8 @@ class StackFrame : public Base {
MS_EXCEPTION_IF_NULL(func_graph_);
std::ostringstream buffer;
buffer << "StackFrame: " << this << ", " << func_graph_->ToString();
if (slot_index_ < node_slots.size()) {
auto current_node = node_slots[slot_index_];
if (slot_index_ < node_slots_.size()) {
auto current_node = node_slots_[slot_index_];
buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")";
} else {
buffer << "(Exhausted..)";
@ -132,7 +132,7 @@ class StackFrame : public Base {
AnalysisContextPtr current_context_;
AnalysisContextPtr parent_context_;
AbstractBasePtrList args_abs_list_;
std::vector<AnfNodePtr> node_slots;
std::vector<AnfNodePtr> node_slots_;
size_t slot_index_;
bool done_;
};

View File

@ -149,7 +149,6 @@ class Adagrad(Optimizer):
super(Adagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(accum, update_slots, self.cls_name)
self.accum = self.parameters.clone(prefix="accum", init=accum)
self.hyper_map = C.HyperMap()
self.update_slots = update_slots
self.opt = P.ApplyAdagrad(update_slots=update_slots)
@ -161,9 +160,9 @@ class Adagrad(Optimizer):
grads = self.scale_grad(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
grads)
success = self.map_reverse(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
grads)
else:
success = self.map_(F.partial(_ada_grad_opt, self.opt, lr), params, accum,
grads)
success = self.map_reverse(F.partial(_ada_grad_opt, self.opt, lr), params, accum,
grads)
return success

View File

@ -329,7 +329,6 @@ class Adam(Optimizer):
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self._is_device = True
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
@ -352,15 +351,15 @@ class Adam(Optimizer):
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
else:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
return success
@Optimizer.target.setter
@ -460,23 +459,23 @@ class AdamWeightDecay(Optimizer):
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
def construct(self, gradients):
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1,
self.moments2, gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
@ -614,8 +613,6 @@ class AdamOffload(Optimizer):
self.use_locking = use_locking
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
self.opt.add_prim_attr("primitive_target", "CPU")
@ -632,11 +629,11 @@ class AdamOffload(Optimizer):
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.map_(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
success = self.map_reverse(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
else:
success = self.map_(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
success = self.map_reverse(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
return success

View File

@ -205,7 +205,6 @@ class FTRL(Optimizer):
self.lr_power = lr_power
if not self.is_group:
self.decay_flags = tuple((lambda: True)() for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.ApplyFtrl(use_locking=use_locking)
self.use_locking = use_locking
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
@ -227,9 +226,9 @@ class FTRL(Optimizer):
grads = self._grad_sparse_indices_deduplicate(grads)
lr = self.get_lr()
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.l1, self.l2, self.lr_power, lr),
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
success = self.map_reverse(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.l1, self.l2, self.lr_power, lr),
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
return success
@Optimizer.target.setter

View File

@ -281,7 +281,6 @@ class Lamb(Optimizer):
if not self.dynamic_lr:
self.global_step = Parameter(initializer(0, [1]), name='global_step')
self.assignadd = P.AssignAdd()
self.hyper_map = C.HyperMap()
self.device_ascend = context.get_context("device_target") == "Ascend"
def construct(self, gradients):
@ -290,20 +289,20 @@ class Lamb(Optimizer):
gradients = self.gradients_centralization(gradients)
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step),
lr, self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step),
lr, self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr),
self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr),
self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr, self.weight_decay),
self.params, self.moments1, self.moments2, gradients,
self.decay_flags, self.optim_filter)
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr, self.weight_decay),
self.params, self.moments1, self.moments2, gradients,
self.decay_flags, self.optim_filter)
if self.use_parallel:
optim_result = F.depend(optim_result, self.broadcast_params(optim_result))

View File

@ -115,7 +115,6 @@ class LARS(Optimizer):
self.decay_flags = optimizer.decay_flags
self.reciprocal_scale = optimizer.reciprocal_scale
self.need_scale = optimizer.need_scale
self.hyper_map = C.HyperMap()
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
self.cast = P.Cast()

View File

@ -248,8 +248,6 @@ class LazyAdam(Optimizer):
self._is_device = True
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
@ -268,17 +266,18 @@ class LazyAdam(Optimizer):
self.beta2_power = self.beta2_power * self.beta2
if self.is_group_lr:
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
else:
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps,
lr),
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
return success
@Optimizer.target.setter

View File

@ -157,7 +157,6 @@ class Momentum(Optimizer):
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
def construct(self, gradients):
@ -168,9 +167,9 @@ class Momentum(Optimizer):
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
self.ps_parameters, self.cache_enable)
success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum),
lr, gradients, params, moments, self.ps_parameters, self.cache_enable)
else:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
self.ps_parameters, self.cache_enable)
success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum, lr),
gradients, params, moments, self.ps_parameters, self.cache_enable)
return success

View File

@ -197,6 +197,9 @@ class Optimizer(Cell):
self.global_step_increase_tensor = Tensor(1, mstype.int32)
self.param_length = len(self.parameters)
self.map_ = C.Map()
self.map_reverse = C.Map(None, True)
self.hyper_map = C.HyperMap()
self.hyper_map_reverse = C.HyperMap(None, True)
self._use_parallel_optimizer()
def _use_parallel_optimizer(self):

View File

@ -162,7 +162,6 @@ class ProximalAdagrad(Optimizer):
self.accum = self.parameters.clone(prefix="accum", init=accum)
self.l1 = Tensor(l1, mstype.float32)
self.l2 = Tensor(l2, mstype.float32)
self.hyper_map = C.HyperMap()
self.use_locking = use_locking
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
@ -176,11 +175,12 @@ class ProximalAdagrad(Optimizer):
grads = self._grad_sparse_indices_deduplicate(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
grads, params, accum)
success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2),
lr, grads, params, accum)
else:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr),
grads, params, accum)
success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2,
lr),
grads, params, accum)
return success
@Optimizer.target.setter

View File

@ -197,7 +197,6 @@ class RMSProp(Optimizer):
self.momentum = momentum
self.ms = self.parameters.clone(prefix="mean_square", init='ones')
self.moment = self.parameters.clone(prefix="moment", init='zeros')
self.hyper_map = C.HyperMap()
self.epsilon = epsilon
self.decay = decay
@ -209,17 +208,20 @@ class RMSProp(Optimizer):
lr = self.get_lr()
if self.centered:
if self.is_group_lr:
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum),
lr, params, self.mg, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr),
params, self.mg, self.ms, self.moment, gradients)
else:
if self.is_group_lr:
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.ms, self.moment, gradients)
success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum),
lr, params, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.ms, self.moment, gradients)
success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr),
params, self.ms, self.moment, gradients)
return success

View File

@ -170,7 +170,6 @@ class SGD(Optimizer):
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.accum = self.parameters.clone(prefix="accum", init='zeros')
self.stat = self.parameters.clone(prefix="stat", init='ones')
self.hyper_map = C.HyperMap()
def construct(self, gradients):
params = self.parameters
@ -180,7 +179,9 @@ class SGD(Optimizer):
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum),
lr, gradients, params, accum, stat)
else:
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum, lr),
gradients, params, accum, stat)
return success

View File

@ -318,7 +318,6 @@ class ThorGpu(Optimizer):
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
self.net = net
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
@ -606,7 +605,6 @@ class ThorAscend(Optimizer):
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))

View File

@ -483,6 +483,7 @@ class HyperMap(HyperMap_):
Args:
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
the operations should be put in the first input of the instance.
reverse (bool): `reverse` is the flag that decides whether the operation is applied in reverse. Only supported in graph mode. Default: False.
Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same length.
@ -517,20 +518,20 @@ class HyperMap(HyperMap_):
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
>>> square_map = HyperMap(square)
>>> square_map = HyperMap(square, False)
>>> output = square_map(nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
"""
def __init__(self, ops=None):
def __init__(self, ops=None, reverse=False):
"""Initialize HyperMap."""
self.ops = ops
if ops:
HyperMap_.__init__(self, ops)
HyperMap_.__init__(self, reverse, ops)
else:
HyperMap_.__init__(self)
HyperMap_.__init__(self, reverse)
def __call__(self, *args):
func = self.ops
@ -555,6 +556,7 @@ class Map(Map_):
Args:
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
the operations should be put in the first input of the instance. Default: None
reverse (bool): `reverse` is the flag that decides whether the operation is applied in reverse. Only supported in graph mode. Default: False.
Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
@ -581,20 +583,20 @@ class Map(Map_):
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
Tensor(shape=[], dtype=Float32, value= 9))
>>> square_map = Map(square)
>>> square_map = Map(square, False)
>>> output = square_map(tensor_list)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
Tensor(shape=[], dtype=Float32, value= 9))
"""
def __init__(self, ops=None):
def __init__(self, ops=None, reverse=False):
"""Initialize Map."""
self.ops = ops
if ops:
Map_.__init__(self, ops)
Map_.__init__(self, reverse, ops)
else:
Map_.__init__(self)
Map_.__init__(self, reverse)
def __call__(self, *args):
func = self.ops