forked from mindspore-Ecosystem/mindspore
!29628 Add monad parameter.
Merge pull request !29628 from gaoyong10/runtime_second12
This commit is contained in:
commit
0ddc057820
|
@ -145,6 +145,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
|
|||
} else if (value->isa<Int32Imm>()) {
|
||||
type = kNumberTypeInt32;
|
||||
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
|
||||
} else if (value->isa<Monad>()) {
|
||||
return;
|
||||
} else {
|
||||
std::string error_info = "Invalid value:" + value->ToString();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
|
|
|
@ -28,26 +28,12 @@ namespace {
|
|||
// Check if node is a value node need to create a device tensor.
|
||||
bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
|
||||
const auto &node = node_with_index.first;
|
||||
size_t index = node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsValueNode<ValueTuple>(node)) {
|
||||
return !HasAbstractMonad(node);
|
||||
}
|
||||
|
||||
const auto &abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.size() <= index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(sub_abstracts[index]);
|
||||
return !sub_abstracts[index]->isa<abstract::AbstractMonad>();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fetch real input node in maketuple.
|
||||
|
@ -421,11 +407,22 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGrap
|
|||
}
|
||||
}
|
||||
|
||||
bool isValidMonadNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->isa<ValueNode>() || node->isa<Parameter>() || AnfAlgo::IsCallNode(node);
|
||||
}
|
||||
|
||||
// Fetch all inputs of node.
|
||||
std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (HasAbstractMonad(node)) {
|
||||
return {};
|
||||
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||
const auto &real_node = real_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
if (isValidMonadNode(real_node)) {
|
||||
return {real_node_with_index};
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid monad node:" << real_node->DebugString();
|
||||
}
|
||||
|
||||
// The node is divided into the following types:
|
||||
|
@ -436,10 +433,11 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
size_t real_index = node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
std::vector<KernelWithIndex> results;
|
||||
// 2. MakeTuple.
|
||||
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCSRTensor) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCOOTensor)) {
|
||||
|
||||
// 2. Tuple node.
|
||||
const PrimitiveSet expand_prims{prim::kPrimMakeTuple, prim::kPrimMakeCSRTensor, prim::kPrimMakeCOOTensor};
|
||||
// The MakeTuple/MakeSparse node need expand and recurse.
|
||||
if (IsOneOfPrimitiveCNode(real_node, expand_prims)) {
|
||||
const auto &cnode = real_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
|
||||
|
@ -518,17 +516,6 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
for (size_t i = 0; i < output_num; ++i) {
|
||||
(void)results.emplace_back(real_node, i);
|
||||
}
|
||||
if (abstract->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
for (const auto &sub_abstract : sub_abstracts) {
|
||||
MS_EXCEPTION_IF_NULL(sub_abstract);
|
||||
if (sub_abstract->isa<abstract::AbstractMonad>()) {
|
||||
(void)results.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
|
@ -692,10 +679,6 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
|
|||
|
||||
for (size_t i = input_start_pos; i < inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]);
|
||||
// skip monad node.
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
const auto &sub_results = FetchInputNodeByNode(inputs[i]);
|
||||
(void)results.insert(results.end(), sub_results.begin(), sub_results.end());
|
||||
}
|
||||
|
@ -1288,9 +1271,9 @@ NodeWithContext ControlNodeParser::FetchBackendParameterWithContextByFrontParame
|
|||
if (AnfAlgo::GetOutputTensorMemSize(node_with_context.first, 0) != 0) {
|
||||
return node_with_context;
|
||||
}
|
||||
MS_LOG(WARNING) << "Backend node:" << node_with_context.first->DebugString()
|
||||
<< " for front node:" << front_parameter_with_index.first->DebugString()
|
||||
<< " index:" << front_parameter_with_index.second << " output size is 0.";
|
||||
MS_LOG(DEBUG) << "Backend node:" << node_with_context.first->DebugString()
|
||||
<< " for front node:" << front_parameter_with_index.first->DebugString()
|
||||
<< " index:" << front_parameter_with_index.second << " output size is 0.";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -1366,9 +1349,6 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr>
|
|||
for (int i = SizeToInt(inputs.size()) - 1, j = SizeToInt(parameters.size()) - 1; i >= 1 && j >= 0; --i, --j) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[IntToSize(i)]);
|
||||
MS_EXCEPTION_IF_NULL(parameters[IntToSize(j)]);
|
||||
if (HasAbstractMonad(inputs[IntToSize(i)])) {
|
||||
continue;
|
||||
}
|
||||
AddFormalToRealParameter(parameters[IntToSize(j)], inputs[IntToSize(i)], call_node_to_func_graphs_,
|
||||
&formal_to_real_parameters);
|
||||
}
|
||||
|
@ -1400,9 +1380,6 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr>
|
|||
for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(parameters[i - kPartialInputStartPos]);
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
AddFormalToRealParameter(parameters[i - kPartialInputStartPos], inputs[i], call_node_to_func_graphs_,
|
||||
&formal_to_real_parameters);
|
||||
}
|
||||
|
@ -1521,9 +1498,6 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
|
|||
const auto &graph = graphs[i];
|
||||
auto device_context = device_contexts[i];
|
||||
for (const auto ¶meter : graph->input_nodes()) {
|
||||
if (HasAbstractMonad(parameter)) {
|
||||
continue;
|
||||
}
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
|
||||
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
|
||||
const auto &front_tuple_parameter_with_index = graph->GetElementInTupleBackendFrontIndexMap(parameter);
|
||||
|
@ -1798,28 +1772,33 @@ void ControlNodeParser::ParseUnRecursionCallNode() {
|
|||
}
|
||||
}
|
||||
|
||||
bool ControlNodeParser::IsCallNodeNeedStack(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto input_with_indexs = FetchInputNodeByCNode(node);
|
||||
for (const auto &input_with_index : input_with_indexs) {
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
// If the call node has call or recursion graph input, a stack created for the call node is required.
|
||||
if (!AnfAlgo::IsCallNode(input_with_index.first)) {
|
||||
if (!input_with_index.first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first);
|
||||
if (graph == nullptr || (!IsRecursionKernelGraph(graph))) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &control_node : control_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(control_node);
|
||||
if (!AnfAlgo::IsCallNode(control_node)) {
|
||||
continue;
|
||||
}
|
||||
auto input_with_indexs = FetchInputNodeByCNode(control_node);
|
||||
for (const auto &input_with_index : input_with_indexs) {
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
// If the call node has call or recursion graph input, a stack created for the call node is required.
|
||||
if (!AnfAlgo::IsCallNode(input_with_index.first)) {
|
||||
if (!input_with_index.first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first);
|
||||
if (graph == nullptr || (!IsRecursionKernelGraph(graph))) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::IsCallNode(control_node) && IsCallNodeNeedStack(control_node)) {
|
||||
(void)need_stack_control_nodes_.emplace(control_node);
|
||||
MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1841,7 +1820,8 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr>
|
|||
MS_LOG(EXCEPTION) << "Invalid return node:" << control_node->DebugString();
|
||||
}
|
||||
|
||||
if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) {
|
||||
if ((!IsInputInSameLevel(control_node)) ||
|
||||
(call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend)))) {
|
||||
(void)need_stack_control_nodes_.emplace(control_node);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) ||
|
||||
|
|
|
@ -227,6 +227,7 @@ class ControlNodeParser {
|
|||
// Get the control nodes and kernel graphs which need to add a stack actor for them.
|
||||
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
|
||||
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||
bool IsCallNodeNeedStack(const AnfNodePtr &node);
|
||||
void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
|
||||
// Parse the level of inputs and outputs of graphs and all control nodes.
|
||||
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
|
|
@ -218,9 +218,6 @@ std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(const Gra
|
|||
// The entrance actor has two parts of node members :
|
||||
// 1. The formal parameters of the subgraph are used to connect the actor's output arrows.
|
||||
for (const auto ¶meter : func_graph->parameters()) {
|
||||
if (HasAbstractMonad(parameter)) {
|
||||
continue;
|
||||
}
|
||||
const auto &abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
|
|
Loading…
Reference in New Issue