!27854 Enable hypermap to use tuple/list as input

Merge pull request !27854 from LiangZhibo/hypermap
This commit is contained in:
i-robot 2022-02-10 03:07:02 +00:00 committed by Gitee
commit e03d1bfb82
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 119 additions and 156 deletions

View File

@ -43,7 +43,6 @@ using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractEllipsis;
@ -75,20 +74,12 @@ void HyperMap::Init() {
}
HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
: MetaFuncGraph("hyper_map"),
fn_leaf_(fn_leaf),
reverse_(reverse),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
: MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), reverse_(reverse), nonleaf_({kObjectTypeList, kObjectTypeTuple}) {
Init();
}
HyperMap::HyperMap(const HyperMap &h)
: MetaFuncGraph("hyper_map"),
fn_leaf_(h.fn_leaf_),
reverse_(h.reverse_),
broadcast_(h.broadcast_),
nonleaf_(h.nonleaf_) {
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), reverse_(h.reverse_), nonleaf_(h.nonleaf_) {
Init();
}
@ -252,61 +243,21 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
return empty_tuple;
}
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
std::size_t attrSize = type->GetAttributes().size();
constexpr size_t kPrimAndTypeLen = 2;
std::vector<AnfNodePtr> inputs;
inputs.reserve(attrSize + kPrimAndTypeLen);
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
// cannot use shared_from_base() also known as this, as it will make a reference cycle on
// hypermap and graph generated, it will cause memory leak.
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
for (std::size_t i = 0; i < attrSize; i++) {
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn_rec);
if (fn_arg) {
inputs2.push_back(fn_arg);
}
size_t size = arg_map.size();
for (size_t j = 0; j < size; j++) {
size_t pos = (reverse_ ? (size - 1 - j) : j);
auto &item = arg_map[pos];
inputs2.push_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
}
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
inputs.insert(inputs.begin() + kPrimAndTypeLen, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
bool found = false;
bool is_leaf = false;
TypeId id = kObjectTypeEnd;
std::pair<AnfNodePtr, TypePtr> pair;
for (auto &item : arg_map) {
pair = item;
id = item.second->type_id();
if (nonleaf_.count(id)) {
found = true;
// The graph building reaches the leaf situation when there exists type that can not be divided any more.
if (!nonleaf_.count(id)) {
is_leaf = true;
break;
}
}
if (found) {
if (!is_leaf) {
// In a nonleaf situation, all arguments must have the same generic.
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
if (item.first != pair.first) {
@ -333,7 +284,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
++idx;
oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";
}
MS_LOG(EXCEPTION) << "The types of arguments in HyperMap must be consistent, "
MS_LOG(EXCEPTION) << "In a nonleaf situation, the types of arguments in HyperMap must be consistent, "
<< "but the types of arguments are inconsistent.\n"
<< oss.str();
}
@ -348,36 +299,11 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
auto type = std::static_pointer_cast<Tuple>(pair.second);
return FullMake(type, func_graph, fn_arg, arg_map);
}
case kObjectTypeClass: {
auto type = std::static_pointer_cast<Class>(pair.second);
return FullMake(type, func_graph, fn_arg, arg_map);
}
default:
return FullMake(func_graph, fn_arg, arg_map);
}
}
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
TypePtr type_tensor = std::make_shared<TensorType>();
bool flag = std::any_of(
args_spec_list.begin(), args_spec_list.end(),
[type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
if (flag && broadcast_) {
ArgsPairList ret;
for (auto &item : args_spec_list) {
if (!IsSubType(item.second, type_tensor)) {
TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
ret.push_back(std::make_pair(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), item.first}),
type_tensor_ele));
} else {
ret.push_back(std::make_pair(item.first, item.second));
}
}
return ret;
}
return args_spec_list;
}
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@ -387,7 +313,6 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
AnfNodePtr ptrFnArg = nullptr;
std::size_t i = 0;
ArgsPairList argmap;
ArgsPairList argmap2;
if (fn_leaf_ == nullptr) {
ptrFnArg = ptr_graph->add_parameter();
i = 1;
@ -398,8 +323,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
}
argmap2 = Harmonize(ptr_graph, argmap);
ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap));
return ptr_graph;
}

View File

@ -56,7 +56,6 @@ class HyperMap : public MetaFuncGraph {
if (this != &h) {
fn_leaf_ = h.fn_leaf_;
reverse_ = h.reverse_;
broadcast_ = h.broadcast_;
nonleaf_ = h.nonleaf_;
if (fn_leaf_) {
name_ = "hyper_map[" + fn_leaf_->name() + "]";
@ -77,15 +76,11 @@ class HyperMap : public MetaFuncGraph {
const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num);
MultitypeFuncGraphPtr fn_leaf_;
bool reverse_;
bool broadcast_;
std::set<TypeId> nonleaf_;
};
using HyperMapPtr = std::shared_ptr<HyperMap>;

View File

@ -206,48 +206,6 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
return func_graph->NewCNodeInOrder(inputs);
}
AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
size_t attrSize = type->GetAttributes().size();
constexpr size_t kPrimAndTypeLen = 2;
std::vector<AnfNodePtr> inputs;
inputs.reserve(attrSize + kPrimAndTypeLen);
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
for (size_t i = 0; i < attrSize; i++) {
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_ << ".";
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
size_t size = arg_pairs.size();
for (size_t j = 0; j < size; j++) {
size_t pos = (reverse_ ? (size - 1 - j) : j);
auto &item = arg_pairs[pos];
inputs2.push_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
}
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
constexpr auto kCallNodePosition = 2;
(void)inputs.insert(inputs.begin() + kCallNodePosition, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
if (arg_pairs.empty()) {
MS_EXCEPTION(TypeError) << "The Map operator must have at least two arguments. But the size of arguments is "
@ -308,13 +266,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c
auto type = std::static_pointer_cast<Tuple>(pair.second);
return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
}
case kObjectTypeClass: {
auto type = std::static_pointer_cast<Class>(pair.second);
return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
}
default:
MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class, but got " << pair.second->ToString()
<< ".";
MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple, but got " << pair.second->ToString() << ".";
}
}

View File

@ -39,7 +39,7 @@ class Map : public MetaFuncGraph {
fn_leaf_(fn_leaf),
reverse_(reverse),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
nonleaf_({kObjectTypeList, kObjectTypeTuple}) {
Init();
}
Map(const Map &map)
@ -75,8 +75,6 @@ class Map : public MetaFuncGraph {
const ArgsPairList &arg_pairs);
AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
std::pair<std::string, std::string> GetMapInputIndex(size_t num);
void Init() {

View File

@ -614,6 +614,9 @@ class HyperMap(HyperMap_):
If `ops` is `None`, the first input is the operation, and the others are inputs.
Note:
Except for the operation input, the number of inputs should be equal to the number of inputs to `ops`.
Outputs:
Sequence or nested sequence, the sequence of output after applying the function.
e.g. `operation(args[0][i], args[1][i])`.

View File

@ -0,0 +1,90 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import context, nn, Tensor
from mindspore import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE)
single_element_fg = C.MultitypeFuncGraph("single_element_fg")
@single_element_fg.register("Tensor")
def single_element_fg_for_tensor(x):
return P.Square()(x)
double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")
@double_elements_fg.register("Tensor", "Tuple")
def double_elements_fg_for_tensor_tuple(x, y):
return P.Tile()(x, y)
class HyperMapNet(nn.Cell):
def __init__(self, fg):
super(HyperMapNet, self).__init__()
self.common_map = C.HyperMap()
self.fg = fg
def construct(self, nest_tensor_list):
output = self.common_map(self.fg, *nest_tensor_list)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_element_hypermap():
"""
Feature: HyperMap
Description: Test whether the HyperMap with single tensor input can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
common_map = HyperMapNet(single_element_fg)
output = common_map((x,))
expect_output_1 = np.array([1.0, 4.0, 9.0])
expect_output_2 = np.array([16.0, 25.0, 36.0])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap():
"""
Feature: HyperMap
Description: Test whether the HyperMap with tensor and tuple inputs can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ((1, 2), (2, 1))
common_map = HyperMapNet(double_elements_fg)
output = common_map((x, y))
expect_output_1 = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
expect_output_2 = np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)

View File

@ -38,7 +38,7 @@ def test_hypermap_noleaf_tuple_list_mix():
"""
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
with pytest.raises(Exception, match="The types of arguments in HyperMap must be consistent"):
with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
main_noleaf((tensor1, 1), [tensor2, 2])
@ -74,7 +74,7 @@ def test_hypermap_noleaf_list_tuple():
"""
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
with pytest.raises(Exception, match="The types of arguments in HyperMap must be consistent"):
with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
main_noleaf([tensor1], (tensor2, tensor2))
@ -87,14 +87,14 @@ def test_tuple_slice_stop_index():
class TupleSliceNet(Cell):
def __init__(self):
super(TupleSliceNet, self).__init__()
self.addN = P.AddN()
self.addn = P.AddN()
self.index_0 = Tensor(3)
def construct(self, tensor_tuple):
tensor_tuple_slice0 = tensor_tuple[:]
tensor_tuple_slice1 = tensor_tuple[self.index_0:"str"] # slice should be Scalar or None, rather than string
sum0 = self.addN(tensor_tuple_slice0)
sum1 = self.addN(tensor_tuple_slice1)
sum0 = self.addn(tensor_tuple_slice0)
sum1 = self.addn(tensor_tuple_slice1)
ret = sum0 + sum1
return ret
@ -120,7 +120,7 @@ def test_tuple_slice_start_index():
class TupleSliceNet(Cell):
def __init__(self):
super(TupleSliceNet, self).__init__()
self.addN = P.AddN()
self.addn = P.AddN()
self.index_0 = Tensor(3)
self.index_1 = Tensor([5])
self.index_3 = Tensor([True])
@ -130,10 +130,10 @@ def test_tuple_slice_start_index():
tensor_tuple_slice1 = tensor_tuple["str":self.index_0]
tensor_tuple_slice2 = tensor_tuple[self.index_3:]
tensor_tuple_slice3 = tensor_tuple[2:self.index_1:]
sum0 = self.addN(tensor_tuple_slice0)
sum1 = self.addN(tensor_tuple_slice1)
sum2 = self.addN(tensor_tuple_slice2)
sum3 = self.addN(tensor_tuple_slice3)
sum0 = self.addn(tensor_tuple_slice0)
sum1 = self.addn(tensor_tuple_slice1)
sum2 = self.addn(tensor_tuple_slice2)
sum3 = self.addn(tensor_tuple_slice3)
ret = sum0 + sum1 + sum2 + sum3
return ret
@ -159,7 +159,7 @@ def test_tuple_slice_step():
class TupleSliceNet(Cell):
def __init__(self):
super(TupleSliceNet, self).__init__()
self.addN = P.AddN()
self.addn = P.AddN()
self.index_0 = Tensor(3)
self.index_1 = Tensor([5])
self.index_3 = Tensor([True])
@ -169,10 +169,10 @@ def test_tuple_slice_step():
tensor_tuple_slice1 = tensor_tuple[:self.index_0]
tensor_tuple_slice2 = tensor_tuple[self.index_3:]
tensor_tuple_slice3 = tensor_tuple[2:self.index_1:0]
sum0 = self.addN(tensor_tuple_slice0)
sum1 = self.addN(tensor_tuple_slice1)
sum2 = self.addN(tensor_tuple_slice2)
sum3 = self.addN(tensor_tuple_slice3)
sum0 = self.addn(tensor_tuple_slice0)
sum1 = self.addn(tensor_tuple_slice1)
sum2 = self.addn(tensor_tuple_slice2)
sum3 = self.addn(tensor_tuple_slice3)
ret = sum0 + sum1 + sum2 + sum3
return ret