Add reverse option for HyperMap and Map, and enable reverse option for optimizers.

commit a3e6009b53 (parent a7ddd10af5)
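In brief: the C++ HyperMap and Map meta func graphs gain a reverse_ flag that makes them emit the per-element calls from the last element to the first (results are still returned in the original order), the flag is exposed through the HyperMap_/Map_ pybind constructors and the Python composite wrappers, and the base Optimizer keeps both a forward and a reverse instance so the optimizer kernels can be applied over the parameters back to front. A minimal sketch of the new user-facing option, assuming only the Python constructors shown later in this diff:

    # Hedged sketch -- mirrors Map.__init__(self, ops=None, reverse=False) and the
    # HyperMap counterpart added in this commit.
    from mindspore.ops import composite as C

    forward_map = C.Map()                       # existing behaviour: apply ops first-to-last
    reverse_map = C.Map(None, True)             # new: apply ops last-to-first (graph mode only)
    forward_hyper_map = C.HyperMap()
    reverse_hyper_map = C.HyperMap(None, True)  # used by the optimizers below as hyper_map_reverse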
@@ -123,16 +123,21 @@ void HyperMap::Init() {
                       {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}

HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      reverse_(reverse),
      broadcast_(false),
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}

HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(h.fn_leaf_),
      reverse_(h.reverse_),
      broadcast_(h.broadcast_),
      nonleaf_(h.nonleaf_) {
  Init();
}

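In the FullMake handlers below, the loop still runs i from 0 to size - 1, but with reverse_ set the element index becomes pos = size - 1 - i and each call node is inserted right after the MakeList/MakeTuple primitive instead of being appended. The effect is that the per-element call nodes are created last-to-first while the output sequence keeps its original element order. A small pure-Python model of that bookkeeping (illustration only, not MindSpore code):

    def mapped_results(elements, fn, reverse=False):
        # Mirrors `pos = (reverse_ ? (size - 1 - i) : i)` plus the insert-at-front
        # versus append logic in HyperMap::FullMake.
        out = []
        for i in range(len(elements)):
            pos = len(elements) - 1 - i if reverse else i
            call = fn(elements[pos])   # built in reverse order when reverse=True
            if reverse:
                out.insert(0, call)    # result lands back at its original position
            else:
                out.append(call)
        return tuple(out)

    assert mapped_results([1, 2, 3], lambda x: x * x) == (1, 4, 9)
    assert mapped_results([1, 2, 3], lambda x: x * x, reverse=True) == (1, 4, 9)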
@ -156,12 +161,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
||||
std::size_t size = type->elements().size();
|
||||
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
auto lhs = std::static_pointer_cast<List>(item.second);
|
||||
MS_EXCEPTION_IF_NULL(lhs);
|
||||
return lhs->elements().size() != size;
|
||||
});
|
||||
size_t size = type->elements().size();
|
||||
size_t num = 0;
|
||||
bool is_not_same =
|
||||
std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
num++;
|
||||
auto lhs = std::static_pointer_cast<List>(item.second);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got "
|
||||
<< item.second->ToString();
|
||||
}
|
||||
if (lhs->elements().size() != size) {
|
||||
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
|
||||
<< lhs->elements().size();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (is_not_same) {
|
||||
MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
|
||||
}
|
||||
|
@ -169,24 +185,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
|
|||
// cannot use shared_from_base() also known as this, as it will make a reference cycle on
|
||||
// hypermap and graph generated, it will cause memory leak.
|
||||
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
|
||||
constexpr size_t kPrimHoldLen = 1;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(size + kPrimHoldLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
|
||||
for (int64_t i = 0; i < SizeToLong(size); ++i) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
|
||||
std::vector<AnfNodePtr> inputs2;
|
||||
inputs2.push_back(fn_rec);
|
||||
if (fn_arg != nullptr) {
|
||||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
(void)std::transform(
|
||||
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
|
||||
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
|
||||
});
|
||||
size_t pos = (reverse_ ? (size - 1 - i) : i);
|
||||
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
|
||||
return func_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
|
||||
});
|
||||
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
inputs.push_back(call_node);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 1, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
@ -196,12 +219,23 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
||||
std::size_t size = type->elements().size();
|
||||
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
auto lhs = std::static_pointer_cast<Tuple>(item.second);
|
||||
MS_EXCEPTION_IF_NULL(lhs);
|
||||
return lhs->elements().size() != size;
|
||||
});
|
||||
size_t size = type->elements().size();
|
||||
size_t num = 0;
|
||||
bool is_not_same =
|
||||
std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
num++;
|
||||
auto lhs = std::static_pointer_cast<Tuple>(item.second);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got "
|
||||
<< item.second->ToString();
|
||||
}
|
||||
if (lhs->elements().size() != size) {
|
||||
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
|
||||
<< lhs->elements().size();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (is_not_same) {
|
||||
MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length";
|
||||
}
|
||||
|
@ -209,23 +243,31 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
// cannot use shared_from_base() also known as this, as it will make a reference cycle on
|
||||
// hypermap and graph generated, it will cause memory leak.
|
||||
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
|
||||
constexpr size_t kPrimHoldLen = 1;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(size + kPrimHoldLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
|
||||
for (int64_t i = 0; i < SizeToLong(size); ++i) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
|
||||
std::vector<AnfNodePtr> inputs2;
|
||||
inputs2.push_back(fn_rec);
|
||||
if (fn_arg != nullptr) {
|
||||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
(void)std::transform(
|
||||
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
|
||||
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
|
||||
});
|
||||
size_t pos = (reverse_ ? (size - 1 - i) : i);
|
||||
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, &pos](std::pair<AnfNodePtr, Any> item) {
|
||||
return func_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
|
||||
});
|
||||
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
inputs.push_back(call_node);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 1, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
@ -235,29 +277,38 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
std::size_t attrSize = type->GetAttributes().size();
|
||||
constexpr size_t kPrimAndTypeLen = 2;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(attrSize + kPrimAndTypeLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
|
||||
inputs.push_back(NewValueNode(type));
|
||||
|
||||
// cannot use shared_from_base() also known as this, as it will make a reference cycle on
|
||||
// hypermap and graph generated, it will cause memory leak.
|
||||
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
|
||||
std::size_t attrSize = type->GetAttributes().size();
|
||||
for (std::size_t i = 0; i < attrSize; ++i) {
|
||||
for (std::size_t i = 0; i < attrSize; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
|
||||
std::vector<AnfNodePtr> inputs2;
|
||||
inputs2.push_back(fn_rec);
|
||||
if (fn_arg) {
|
||||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
int64_t j = 0;
|
||||
for (auto item : arg_map) {
|
||||
inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
|
||||
j++;
|
||||
size_t size = arg_map.size();
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
size_t pos = (reverse_ ? (size - 1 - j) : j);
|
||||
auto &item = arg_map[pos];
|
||||
inputs2.push_back(
|
||||
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
|
||||
}
|
||||
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
inputs.push_back(call_node);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 2, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
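One detail in the Class handler above: for lists and tuples the reversed call nodes are inserted at inputs.begin() + 1 (right after the MakeList/MakeTuple primitive), while for classes they go to inputs.begin() + 2, because inputs[0] holds the MakeRecord primitive and inputs[1] the type node. A hedged sketch of the three layouts being assembled:

    # Illustration only -- layout of the `inputs` vector in the three FullMake overloads:
    #   List:  [kPrimMakeList,   call_0, call_1, ...]  -> reverse inserts at index 1
    #   Tuple: [kPrimMakeTuple,  call_0, call_1, ...]  -> reverse inserts at index 1
    #   Class: [kPrimMakeRecord, type,   call_0, ...]  -> reverse inserts at index 2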
@@ -383,8 +434,9 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList

REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                         (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                           .def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
                                py::arg("ops"))
                           .def(py::init<bool>(), py::arg("reverse"));
                       }));

FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
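Together with the HyperMap.__init__ change near the end of this commit, the flag travels from the Python composite API down to the C++ builder. A hedged trace of the call chain (each step appears elsewhere in this diff):

    hyper_map_reverse = C.HyperMap(None, True)
    #  -> HyperMap_.__init__(self, reverse)            (Python wrapper, ops is None)
    #  -> py::init<bool>(), py::arg("reverse")         (pybind overload registered above)
    #  -> HyperMap(bool reverse, fn_leaf = nullptr)    (C++ constructor storing reverse_)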
|
|
@@ -48,12 +48,13 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;

class HyperMap : public MetaFuncGraph {
 public:
  explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  HyperMap(const HyperMap &h);
  void Init();
  HyperMap &operator=(const HyperMap &h) {
    if (this != &h) {
      fn_leaf_ = h.fn_leaf_;
      reverse_ = h.reverse_;
      broadcast_ = h.broadcast_;
      nonleaf_ = h.nonleaf_;
      if (fn_leaf_) {

@@ -82,6 +83,7 @@ class HyperMap : public MetaFuncGraph {
  ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);

  MultitypeFuncGraphPtr fn_leaf_;
  bool reverse_;
  bool broadcast_;
  std::set<TypeId> nonleaf_;
};

@@ -89,7 +91,8 @@ using HyperMapPtr = std::shared_ptr<HyperMap>;

class HyperMapPy : public HyperMap {
 public:
  explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
  explicit HyperMapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
      : HyperMap(reverse, fn_leaf) {}
  ~HyperMapPy() override = default;
  MS_DECLARE_PARENT(HyperMapPy, HyperMap)
};
|
|
@ -71,21 +71,33 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
|
|||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
||||
std::size_t size = type->elements().size();
|
||||
size_t num = 0;
|
||||
bool is_not_same =
|
||||
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
num++;
|
||||
auto lhs = std::dynamic_pointer_cast<List>(item.second);
|
||||
MS_EXCEPTION_IF_NULL(lhs);
|
||||
return lhs->elements().size() != size;
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a List, but got "
|
||||
<< item.second->ToString();
|
||||
}
|
||||
if (lhs->elements().size() != size) {
|
||||
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
|
||||
<< lhs->elements().size();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (is_not_same) {
|
||||
MS_LOG(EXCEPTION) << "List in Map should have same length";
|
||||
}
|
||||
|
||||
constexpr size_t kPrimHoldLen = 1;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(size + kPrimHoldLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
|
||||
for (int64_t i = 0; i < SizeToLong(size); ++i) {
|
||||
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_;
|
||||
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
|
||||
auto fn = NewValueNode(ptrGraph);
|
||||
|
||||
|
@ -95,13 +107,19 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
|
|||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
(void)std::transform(
|
||||
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
|
||||
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
|
||||
});
|
||||
size_t pos = (reverse_ ? (size - 1 - i) : i);
|
||||
(void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
|
||||
return func_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
|
||||
});
|
||||
|
||||
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 1, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
@ -111,22 +129,34 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
||||
std::size_t size = type->elements().size();
|
||||
size_t size = type->elements().size();
|
||||
size_t num = 0;
|
||||
bool is_not_same =
|
||||
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
|
||||
num++;
|
||||
auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
|
||||
MS_EXCEPTION_IF_NULL(lhs);
|
||||
return lhs->elements().size() != size;
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The elements[" << num - 1 << "] has wrong type, expected a Tuple, but got "
|
||||
<< item.second->ToString();
|
||||
}
|
||||
if (lhs->elements().size() != size) {
|
||||
MS_LOG(ERROR) << "The elements[" << num - 1 << "] has different length, expected " << size << ", but got "
|
||||
<< lhs->elements().size();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (is_not_same) {
|
||||
MS_LOG(EXCEPTION) << "tuple in Map should have same length";
|
||||
}
|
||||
|
||||
constexpr size_t kPrimHoldLen = 1;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(size + kPrimHoldLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
|
||||
for (int64_t i = 0; i < SizeToLong(size); ++i) {
|
||||
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_;
|
||||
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
|
||||
auto fn = NewValueNode(ptrGraph);
|
||||
|
||||
|
@ -136,13 +166,19 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
(void)std::transform(
|
||||
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
|
||||
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
|
||||
});
|
||||
size_t pos = (reverse_ ? (size - 1 - i) : i);
|
||||
(void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
|
||||
[&func_graph, &pos](const std::pair<AnfNodePtr, Any> &item) {
|
||||
return func_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
|
||||
});
|
||||
|
||||
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 1, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
@ -152,13 +188,15 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
size_t attrSize = type->GetAttributes().size();
|
||||
constexpr size_t kPrimAndTypeLen = 2;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(attrSize + kPrimAndTypeLen);
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
|
||||
inputs.push_back(NewValueNode(type));
|
||||
|
||||
std::size_t attrSize = type->GetAttributes().size();
|
||||
for (std::size_t i = 0; i < attrSize; ++i) {
|
||||
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
|
||||
for (size_t i = 0; i < attrSize; i++) {
|
||||
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_;
|
||||
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
|
||||
auto fn = NewValueNode(ptrGraph);
|
||||
|
||||
|
@ -168,13 +206,20 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
|
|||
inputs2.push_back(fn_arg);
|
||||
}
|
||||
|
||||
int64_t j = 0;
|
||||
for (auto item : arg_pairs) {
|
||||
inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
|
||||
j++;
|
||||
size_t size = arg_pairs.size();
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
size_t pos = (reverse_ ? (size - 1 - j) : j);
|
||||
auto &item = arg_pairs[pos];
|
||||
inputs2.push_back(
|
||||
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
|
||||
}
|
||||
|
||||
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
|
||||
auto call_node = func_graph->NewCNodeInOrder(inputs2);
|
||||
if (reverse_) {
|
||||
inputs.insert(inputs.begin() + 2, call_node);
|
||||
} else {
|
||||
inputs.emplace_back(call_node);
|
||||
}
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
@ -284,8 +329,9 @@ abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args
|
|||
|
||||
REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
|
||||
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
|
||||
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
|
||||
.def(py::init<>());
|
||||
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
|
||||
py::arg("ops"))
|
||||
.def(py::init<bool>(), py::arg("reverse"));
|
||||
}));
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,21 +33,28 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
|
|||
|
||||
class Map : public MetaFuncGraph {
|
||||
public:
|
||||
explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
|
||||
explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
|
||||
: MetaFuncGraph("map"),
|
||||
fn_leaf_(fn_leaf),
|
||||
reverse_(reverse),
|
||||
broadcast_(false),
|
||||
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
|
||||
Init();
|
||||
}
|
||||
Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
|
||||
Map(const Map &map)
|
||||
: MetaFuncGraph("map"),
|
||||
fn_leaf_(map.fn_leaf_),
|
||||
reverse_(map.reverse_),
|
||||
broadcast_(map.broadcast_),
|
||||
nonleaf_(map.nonleaf_) {
|
||||
Init();
|
||||
}
|
||||
Map &operator=(const Map &h) {
|
||||
if (this != &h) {
|
||||
fn_leaf_ = h.fn_leaf_;
|
||||
broadcast_ = h.broadcast_;
|
||||
nonleaf_ = h.nonleaf_;
|
||||
Map &operator=(const Map &map) {
|
||||
if (this != &map) {
|
||||
fn_leaf_ = map.fn_leaf_;
|
||||
reverse_ = map.reverse_;
|
||||
broadcast_ = map.broadcast_;
|
||||
nonleaf_ = map.nonleaf_;
|
||||
if (fn_leaf_) {
|
||||
name_ = "map[" + fn_leaf_->name() + "]";
|
||||
}
|
||||
|
@ -81,13 +88,15 @@ class Map : public MetaFuncGraph {
|
|||
}
|
||||
|
||||
MultitypeFuncGraphPtr fn_leaf_;
|
||||
bool reverse_;
|
||||
bool broadcast_;
|
||||
std::set<TypeId> nonleaf_;
|
||||
};
|
||||
using MapPtr = std::shared_ptr<Map>;
|
||||
class MapPy : public Map {
|
||||
public:
|
||||
explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {}
|
||||
explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
|
||||
: Map(reverse, fn_leaf) {}
|
||||
~MapPy() override = default;
|
||||
MS_DECLARE_PARENT(MapPy, Map)
|
||||
};
|
||||
|
|
|
@ -127,22 +127,13 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
<< ", res_base: " << res_base->ToString();
|
||||
break;
|
||||
}
|
||||
|
||||
// Overwrite the result if func graph is stub.
|
||||
if (current_stack_frame->func_graph()->stub()) {
|
||||
eval_result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
|
||||
}
|
||||
// Save func graph eval result for specialize.
|
||||
auto evaluator = current_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result);
|
||||
|
||||
// Leave current func graph.
|
||||
LeaveStackFrame(engine, current_stack_frame);
|
||||
// Switch the stack frame.
|
||||
auto last_stack_frame = current_stack_frame;
|
||||
current_stack_frame = stack_frames.top();
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
|
||||
current_stack_frame->Back(engine, eval_result);
|
||||
current_stack_frame->Back(engine, last_stack_frame, eval_result);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -223,7 +214,6 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
|
||||
<< fg->ToString() << "();";
|
||||
}
|
||||
|
||||
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
|
||||
const auto ¶meters = fg->parameters();
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
|
@ -399,8 +389,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
|
|||
|
||||
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
const string evaluator_name = ToString();
|
||||
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_loc_, std::try_to_lock);
|
||||
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_lock_, std::try_to_lock);
|
||||
if (!eval_lock.owns_lock()) {
|
||||
auto py_tstate = PyEval_SaveThread();
|
||||
eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout));
|
||||
|
@ -420,26 +409,26 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|||
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
||||
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
|
||||
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||
const std::string &evaluator_name = ToString();
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result == nullptr) {
|
||||
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
||||
EvalResultPtr res = Eval(engine, args_spec_list);
|
||||
if (res->abstract() == nullptr) {
|
||||
eval_result = Eval(engine, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
if (eval_result->abstract() == nullptr) {
|
||||
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
||||
}
|
||||
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << res->abstract()->ToString() << ".";
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, res);
|
||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||
return res;
|
||||
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << eval_result->abstract()->ToString() << ".";
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
|
||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||
return eval_result;
|
||||
}
|
||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
|
|
|
@ -40,9 +40,9 @@ using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
|
|||
class Evaluator : public Base {
|
||||
public:
|
||||
explicit Evaluator(const std::string &id)
|
||||
: evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
|
||||
attr_cache_(std::make_shared<EvaluatorAttrCache>()),
|
||||
identifier_(id) {}
|
||||
: identifier_(id),
|
||||
evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
|
||||
attr_cache_(std::make_shared<EvaluatorAttrCache>()) {}
|
||||
~Evaluator() override = default;
|
||||
MS_DECLARE_PARENT(Evaluator, Base);
|
||||
|
||||
|
@ -92,13 +92,16 @@ class Evaluator : public Base {
|
|||
EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
|
||||
EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
|
||||
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
|
||||
EvaluatorAttrCachePtr attr_cache_;
|
||||
std::recursive_timed_mutex &eval_lock() { return eval_lock_; }
|
||||
|
||||
protected:
|
||||
std::string identifier_;
|
||||
|
||||
AnfNodeWeakPtr bound_node_;
|
||||
std::recursive_timed_mutex eval_loc_;
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
|
||||
std::recursive_timed_mutex eval_lock_;
|
||||
|
||||
private:
|
||||
EvaluatorAttrCachePtr attr_cache_;
|
||||
};
|
||||
|
||||
class PrimEvaluator : public Evaluator {
|
||||
|
|
|
@ -65,7 +65,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
|
||||
// Evaluate the inputs firstly. Build arguments for the func graph.
|
||||
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
|
||||
|
||||
// Check if already evaluated before.
|
||||
if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) {
|
||||
return nullptr;
|
||||
|
@ -80,7 +79,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
<< fg->parameters().size() << ", but the number of provided arguments is "
|
||||
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
|
||||
|
@ -133,8 +131,20 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
|||
return node_eval_result;
|
||||
}
|
||||
|
||||
// Return back from branch func graph.
|
||||
void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) {
|
||||
// Return back from child func graph.
|
||||
void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame,
|
||||
const EvalResultPtr &eval_result) {
|
||||
// Overwrite the result if func graph is stub.
|
||||
EvalResultPtr result = eval_result;
|
||||
if (last_stack_frame->func_graph()->stub()) {
|
||||
result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
|
||||
}
|
||||
// Save func graph eval result for specialize.
|
||||
auto evaluator = last_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result);
|
||||
|
||||
// Continue saving node's result for parent func graph.
|
||||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
|
|
|
@ -46,7 +46,7 @@ class StackFrame : public Base {
|
|||
virtual ~StackFrame() = default;
|
||||
|
||||
void Load() {
|
||||
node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
|
||||
node_slots_ = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
|
||||
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
@ -61,17 +61,17 @@ class StackFrame : public Base {
|
|||
// Run one step in current func graph.
|
||||
EvalResultPtr Step(const AnalysisEnginePtr &engine);
|
||||
// Return back from branch func graph.
|
||||
void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result);
|
||||
void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result);
|
||||
|
||||
bool Done() { return done_; }
|
||||
|
||||
AnfNodePtr &CurrentNode() {
|
||||
if (slot_index_ >= node_slots.size()) {
|
||||
if (slot_index_ >= node_slots_.size()) {
|
||||
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToString()
|
||||
<< " is invalid. Try to access frame sequence by index " << slot_index_
|
||||
<< ", while the size is " << node_slots.size() << ".";
|
||||
<< ", while the size is " << node_slots_.size() << ".";
|
||||
}
|
||||
return node_slots[slot_index_];
|
||||
return node_slots_[slot_index_];
|
||||
}
|
||||
|
||||
AnfNodePtr &NextNode() {
|
||||
|
@ -97,8 +97,8 @@ class StackFrame : public Base {
|
|||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
std::ostringstream buffer;
|
||||
buffer << "StackFrame: " << this << ", " << func_graph_->ToString();
|
||||
if (slot_index_ < node_slots.size()) {
|
||||
auto current_node = node_slots[slot_index_];
|
||||
if (slot_index_ < node_slots_.size()) {
|
||||
auto current_node = node_slots_[slot_index_];
|
||||
buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")";
|
||||
} else {
|
||||
buffer << "(Exhausted..)";
|
||||
|
@ -132,7 +132,7 @@ class StackFrame : public Base {
|
|||
AnalysisContextPtr current_context_;
|
||||
AnalysisContextPtr parent_context_;
|
||||
AbstractBasePtrList args_abs_list_;
|
||||
std::vector<AnfNodePtr> node_slots;
|
||||
std::vector<AnfNodePtr> node_slots_;
|
||||
size_t slot_index_;
|
||||
bool done_;
|
||||
};
|
||||
|
|
|
@@ -149,7 +149,6 @@ class Adagrad(Optimizer):
        super(Adagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(accum, update_slots, self.cls_name)
        self.accum = self.parameters.clone(prefix="accum", init=accum)
        self.hyper_map = C.HyperMap()
        self.update_slots = update_slots
        self.opt = P.ApplyAdagrad(update_slots=update_slots)

@@ -161,9 +160,9 @@ class Adagrad(Optimizer):
        grads = self.scale_grad(grads)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
                                grads)
            success = self.map_reverse(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
                                       grads)
        else:
            success = self.map_(F.partial(_ada_grad_opt, self.opt, lr), params, accum,
                                grads)
            success = self.map_reverse(F.partial(_ada_grad_opt, self.opt, lr), params, accum,
                                       grads)
        return success
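Every optimizer below follows the same two-step pattern seen here for Adagrad: the per-class self.hyper_map = C.HyperMap() line is dropped from __init__ (the base Optimizer now provides the map helpers, see the optimizer.py hunk further down), and construct switches from self.map_ / self.hyper_map to the self.map_reverse / self.hyper_map_reverse variants. A hedged generic sketch of the pattern; _opt_kernel and the trailing argument tuples are stand-ins for the real _xxx_opt partials:

    def construct(self, gradients):
        params = self.parameters
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            # group learning rates: lr is one more tuple mapped over per parameter
            success = self.map_reverse(F.partial(_opt_kernel, self.opt), lr,
                                       gradients, params, self.moments)
        else:
            success = self.map_reverse(F.partial(_opt_kernel, self.opt, lr),
                                       gradients, params, self.moments)
        return success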
|
|
|
@ -329,7 +329,6 @@ class Adam(Optimizer):
|
|||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self._is_device = True
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
|
@ -352,15 +351,15 @@ class Adam(Optimizer):
|
|||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group_lr:
|
||||
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
else:
|
||||
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
success = self.map_reverse(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
@ -460,23 +459,23 @@ class AdamWeightDecay(Optimizer):
|
|||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, gradients):
|
||||
lr = self.get_lr()
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||
lr, self.weight_decay, self.parameters, self.moments1,
|
||||
self.moments2, gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
|
||||
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
|
||||
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
|
||||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay),
|
||||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
self.broadcast_params(optim_result)
|
||||
return optim_result
|
||||
|
@ -614,8 +613,6 @@ class AdamOffload(Optimizer):
|
|||
self.use_locking = use_locking
|
||||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
|
||||
self.opt.add_prim_attr("primitive_target", "CPU")
|
||||
|
||||
|
@ -632,11 +629,11 @@ class AdamOffload(Optimizer):
|
|||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group_lr:
|
||||
success = self.map_(F.partial(_adam_opt, self.opt,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
success = self.map_reverse(F.partial(_adam_opt, self.opt,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
else:
|
||||
success = self.map_(F.partial(_adam_opt, self.opt,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
success = self.map_reverse(F.partial(_adam_opt, self.opt,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
return success
|
||||
|
|
|
@ -205,7 +205,6 @@ class FTRL(Optimizer):
|
|||
self.lr_power = lr_power
|
||||
if not self.is_group:
|
||||
self.decay_flags = tuple((lambda: True)() for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyFtrl(use_locking=use_locking)
|
||||
self.use_locking = use_locking
|
||||
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
|
||||
|
@ -227,9 +226,9 @@ class FTRL(Optimizer):
|
|||
grads = self._grad_sparse_indices_deduplicate(grads)
|
||||
lr = self.get_lr()
|
||||
|
||||
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.l1, self.l2, self.lr_power, lr),
|
||||
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
|
||||
success = self.map_reverse(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.l1, self.l2, self.lr_power, lr),
|
||||
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
|
|
@ -281,7 +281,6 @@ class Lamb(Optimizer):
|
|||
if not self.dynamic_lr:
|
||||
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.device_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
def construct(self, gradients):
|
||||
|
@ -290,20 +289,20 @@ class Lamb(Optimizer):
|
|||
gradients = self.gradients_centralization(gradients)
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step),
|
||||
lr, self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step),
|
||||
lr, self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr),
|
||||
self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr),
|
||||
self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr, self.weight_decay),
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flags, self.optim_filter)
|
||||
optim_result = self.hyper_map_reverse(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr, self.weight_decay),
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flags, self.optim_filter)
|
||||
|
||||
if self.use_parallel:
|
||||
optim_result = F.depend(optim_result, self.broadcast_params(optim_result))
|
||||
|
|
|
@ -115,7 +115,6 @@ class LARS(Optimizer):
|
|||
self.decay_flags = optimizer.decay_flags
|
||||
self.reciprocal_scale = optimizer.reciprocal_scale
|
||||
self.need_scale = optimizer.need_scale
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
|
||||
self.cast = P.Cast()
|
||||
|
||||
|
|
|
@ -248,8 +248,6 @@ class LazyAdam(Optimizer):
|
|||
self._is_device = True
|
||||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
|
@ -268,17 +266,18 @@ class LazyAdam(Optimizer):
|
|||
self.beta2_power = self.beta2_power * self.beta2
|
||||
|
||||
if self.is_group_lr:
|
||||
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
|
||||
self.cache_enable)
|
||||
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
|
||||
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
|
||||
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
|
||||
self.cache_enable)
|
||||
else:
|
||||
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
|
||||
self.cache_enable)
|
||||
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
|
||||
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
|
||||
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps,
|
||||
lr),
|
||||
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
|
||||
self.cache_enable)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
|
|
@ -157,7 +157,6 @@ class Momentum(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
|
||||
def construct(self, gradients):
|
||||
|
@ -168,9 +167,9 @@ class Momentum(Optimizer):
|
|||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
|
||||
self.ps_parameters, self.cache_enable)
|
||||
success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum),
|
||||
lr, gradients, params, moments, self.ps_parameters, self.cache_enable)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments,
|
||||
self.ps_parameters, self.cache_enable)
|
||||
success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum, lr),
|
||||
gradients, params, moments, self.ps_parameters, self.cache_enable)
|
||||
return success
|
||||
|
|
|
@@ -197,6 +197,9 @@ class Optimizer(Cell):
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        self.map_reverse = C.Map(None, True)
        self.hyper_map = C.HyperMap()
        self.hyper_map_reverse = C.HyperMap(None, True)
        self._use_parallel_optimizer()

    def _use_parallel_optimizer(self):
|
|
|
@ -160,7 +160,6 @@ class ProximalAdagrad(Optimizer):
|
|||
self.accum = self.parameters.clone(prefix="accum", init=accum)
|
||||
self.l1 = Tensor(l1, mstype.float32)
|
||||
self.l2 = Tensor(l2, mstype.float32)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.use_locking = use_locking
|
||||
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
|
||||
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
|
||||
|
@ -174,11 +173,12 @@ class ProximalAdagrad(Optimizer):
|
|||
grads = self._grad_sparse_indices_deduplicate(grads)
|
||||
lr = self.get_lr()
|
||||
if self.is_group_lr:
|
||||
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
|
||||
grads, params, accum)
|
||||
success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2),
|
||||
lr, grads, params, accum)
|
||||
else:
|
||||
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr),
|
||||
grads, params, accum)
|
||||
success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2,
|
||||
lr),
|
||||
grads, params, accum)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
|
|
@ -197,7 +197,6 @@ class RMSProp(Optimizer):
|
|||
self.momentum = momentum
|
||||
self.ms = self.parameters.clone(prefix="mean_square", init='ones')
|
||||
self.moment = self.parameters.clone(prefix="moment", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.epsilon = epsilon
|
||||
self.decay = decay
|
||||
|
||||
|
@ -209,17 +208,20 @@ class RMSProp(Optimizer):
|
|||
lr = self.get_lr()
|
||||
if self.centered:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
|
||||
success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum),
|
||||
lr, params, self.mg, self.ms, self.moment, gradients)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
|
||||
|
||||
success = self.hyper_map_reverse(F.partial(_centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum, lr),
|
||||
params, self.mg, self.ms, self.moment, gradients)
|
||||
else:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum), lr, params, self.ms, self.moment, gradients)
|
||||
success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum),
|
||||
lr, params, self.ms, self.moment, gradients)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum, lr), params, self.ms, self.moment, gradients)
|
||||
success = self.hyper_map_reverse(F.partial(_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum, lr),
|
||||
params, self.ms, self.moment, gradients)
|
||||
return success
|
||||
|
|
|
@ -170,7 +170,6 @@ class SGD(Optimizer):
|
|||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.accum = self.parameters.clone(prefix="accum", init='zeros')
|
||||
self.stat = self.parameters.clone(prefix="stat", init='ones')
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
|
@ -180,7 +179,9 @@ class SGD(Optimizer):
|
|||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
|
||||
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum),
|
||||
lr, gradients, params, accum, stat)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
|
||||
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum, lr),
|
||||
gradients, params, accum, stat)
|
||||
return success
|
||||
|
|
|
@ -318,7 +318,6 @@ class ThorGpu(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
self.net = net
|
||||
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
|
||||
|
@ -606,7 +605,6 @@ class ThorAscend(Optimizer):
|
|||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.net = net
|
||||
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
|
||||
|
|
|
@@ -483,6 +483,7 @@ class HyperMap(HyperMap_):
    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance.
        reverse (bool): `reverse` is the flag to decide if apply the operation reversely. Only supported in graph mode.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same length.

@@ -517,20 +518,20 @@ class HyperMap(HyperMap_):
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
        >>> square_map = HyperMap(square)
        >>> square_map = HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """

    def __init__(self, ops=None):
    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        if ops:
            HyperMap_.__init__(self, ops)
            HyperMap_.__init__(self, reverse, ops)
        else:
            HyperMap_.__init__(self)
            HyperMap_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
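A hedged usage sketch building on the HyperMap docstring example above (it assumes the square MultitypeFuncGraph and nest_tensor_list defined earlier in that example, outside this hunk):

    # All three forms are valid after this change; reverse=True only changes the
    # order in which `square` is applied, not the nesting or order of the result.
    square_map = HyperMap(square)             # same as before this commit
    square_map = HyperMap(square, False)      # explicit forward order
    square_map_rev = HyperMap(square, True)   # apply to elements last-to-first (graph mode only)
    output = square_map_rev(nest_tensor_list)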
|
@@ -555,6 +556,7 @@ class Map(Map_):
    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance. Default: None
        reverse (bool): `reverse` is the flag to decide if apply the operation reversely. Only supported in graph mode.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,

@@ -581,20 +583,20 @@ class Map(Map_):
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
        >>> square_map = Map(square)
        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """

    def __init__(self, ops=None):
    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        if ops:
            Map_.__init__(self, ops)
            Map_.__init__(self, reverse, ops)
        else:
            Map_.__init__(self)
            Map_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
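The flat counterpart works the same way (again assuming the square and tensor_list from the Map docstring example above): Map applies ops over a single level of sequence, HyperMap recurses into nested ones, and both now accept the trailing reverse flag.

    square_map_rev = Map(square, True)     # apply `square` to tensor_list last-to-first
    output = square_map_rev(tensor_list)   # result tuple keeps the original element order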