forked from mindspore-Ecosystem/mindspore
!3048 support use valuelist or valuetuple of primitives
Merge pull request !3048 from amongo/SupportPrimitiveList
This commit is contained in:
commit
45ad430af2
|
@ -168,15 +168,15 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsAllGraphInValueSequence(const std::vector<ValuePtr> &value_vec) {
|
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
|
||||||
for (auto &elem : value_vec) {
|
for (auto &elem : value_vec) {
|
||||||
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
|
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
|
||||||
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
|
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
|
||||||
auto is_graph = IsAllGraphInValueSequence(vec);
|
auto is_graph = IsAllFuncInValueSequence(vec);
|
||||||
if (!is_graph) {
|
if (!is_graph) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else if (!elem->isa<FuncGraph>()) {
|
} else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,6 +196,8 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
|
||||||
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
|
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
|
||||||
manager->AddFuncGraph(new_fg);
|
manager->AddFuncGraph(new_fg);
|
||||||
node = NewValueNode(new_fg);
|
node = NewValueNode(new_fg);
|
||||||
|
} else if (elem->isa<Primitive>()) {
|
||||||
|
node = NewValueNode(elem);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
|
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
|
||||||
}
|
}
|
||||||
|
@ -205,19 +207,21 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
|
||||||
return cnode;
|
return cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
// transform the ValueTuple or ValueList of graph node to make tuple of const graph node
|
// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node
|
||||||
bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
|
||||||
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
|
const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
|
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
|
||||||
if (!IsAllGraphInValueSequence(value_vec)) {
|
if (!IsAllFuncInValueSequence(value_vec)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
|
||||||
// So if has graph in list, try to replace the node with make tuple of graph value node.
|
// So if has graph in list, try to replace the node with make tuple of graph value node.
|
||||||
// we do this because the graphmanger won't investigate the graph inside valuetuple,
|
// we do this because the graphmanger won't investigate the graph inside valuetuple,
|
||||||
// change the vector of graph to be make_tuple of graph value node
|
// change the vector of graph to be make_tuple of graph value node.
|
||||||
|
// (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all
|
||||||
|
// independent nodes.
|
||||||
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
|
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
|
||||||
// replace the ret ptr to be make tuple of graph value node
|
// replace the ret ptr to be make tuple of graph value node
|
||||||
*transformed = node_tuple_graphs;
|
*transformed = node_tuple_graphs;
|
||||||
|
@ -251,8 +255,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
||||||
|
|
||||||
// if the constant node is constant of vector of graph ,add graph to manager
|
// if the constant node is constant of vector of graph ,add graph to manager
|
||||||
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
|
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
|
||||||
(void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
|
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
|
||||||
&resolved_node);
|
&resolved_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
TraceManager::EndTrace();
|
TraceManager::EndTrace();
|
||||||
|
|
|
@ -17,6 +17,9 @@ from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, Model, context
|
from mindspore import Tensor, Model, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
from mindspore.nn import ReLU
|
from mindspore.nn import ReLU
|
||||||
from ...ut_filter import non_graph_engine
|
from ...ut_filter import non_graph_engine
|
||||||
|
@ -66,3 +69,58 @@ def function_access_base(number):
|
||||||
def test_access_0040():
|
def test_access_0040():
|
||||||
""" test_access_0040 """
|
""" test_access_0040 """
|
||||||
function_access_base(2)
|
function_access_base(2)
|
||||||
|
|
||||||
|
|
||||||
|
class OpSeqNet(Cell):
|
||||||
|
def __init__(self, loop_count=1):
|
||||||
|
super().__init__()
|
||||||
|
self.loop_count = loop_count
|
||||||
|
self.op_seq = (P.Sqrt(), P.Reciprocal(), P.Square())
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
t = x
|
||||||
|
for op in self.op_seq:
|
||||||
|
t = op(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_seq_test():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = OpSeqNet()
|
||||||
|
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||||
|
input_me = Tensor(input_np)
|
||||||
|
net(input_me)
|
||||||
|
|
||||||
|
|
||||||
|
_grad_fusion = C.MultitypeFuncGraph("grad_fushion")
|
||||||
|
|
||||||
|
|
||||||
|
@_grad_fusion.register("Tensor", "Function")
|
||||||
|
def tensor_grad_scale(x, op):
|
||||||
|
return op(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceTest(Cell):
|
||||||
|
def __init__(self, loop_count=1):
|
||||||
|
super().__init__()
|
||||||
|
self.op_list = ()
|
||||||
|
self.fushion_flag = [0, 1, 1, 0, 1, 0]
|
||||||
|
for i in self.fushion_flag:
|
||||||
|
op = P.AllReduce().add_prim_attr('fusion', i)
|
||||||
|
self.op_list = self.op_list + (op,)
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
ret = ()
|
||||||
|
for _ in self.fushion_flag:
|
||||||
|
ret = ret + (x,)
|
||||||
|
fushion_res = self.hyper_map(F.partial(_grad_fusion), ret, self.op_list)
|
||||||
|
return fushion_res
|
||||||
|
|
||||||
|
|
||||||
|
def test_allreduce_fushio_test():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = AllReduceTest()
|
||||||
|
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||||
|
input_me = Tensor(input_np)
|
||||||
|
net(input_me)
|
||||||
|
|
Loading…
Reference in New Issue