forked from mindspore-Ecosystem/mindspore
Fix partial primitive poly node
This commit is contained in:
parent
703c1b26dd
commit
ffa3352088
|
@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
|
|||
}
|
||||
auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
|
||||
|
||||
if (func->context() != nullptr) {
|
||||
if (!IsVisible(func_graph_, func->context()->func_graph())) {
|
||||
MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
|
||||
}
|
||||
} else {
|
||||
if (func->context() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
|
||||
}
|
||||
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
|
||||
|
@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
// First element is partial, second is func so arg is start from 2
|
||||
(void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
|
||||
func = inputs[1];
|
||||
new_inputs = args;
|
||||
(void)new_inputs.insert(new_inputs.begin(), func);
|
||||
}
|
||||
new_inputs = args;
|
||||
(void)new_inputs.insert(new_inputs.begin(), func);
|
||||
|
||||
AbstractBasePtrList argvals;
|
||||
MS_EXCEPTION_IF_NULL(new_inputs[0]);
|
||||
|
@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
|
||||
}
|
||||
|
||||
if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
|
||||
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
||||
new_inputs[0] = wrapped_node;
|
||||
if (!func->isa<ValueNode>()) {
|
||||
MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
|
||||
if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
|
||||
auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
|
||||
EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
|
||||
std::pair<AbstractBasePtrList, AbstractBasePtr> result;
|
||||
AbstractBasePtrList empty_args;
|
||||
auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
|
||||
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
|
||||
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
|
||||
if (status == kSpecializeFindUniqueArgvalPoly ||
|
||||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
|
||||
func->abstract()->isa<PartialAbstractClosure>()))) {
|
||||
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
||||
new_inputs[0] = wrapped_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (CanSpecializeNode(func)) {
|
||||
|
|
|
@ -14,9 +14,12 @@
|
|||
# ============================================================================
|
||||
""" test nn ops """
|
||||
import numpy as np
|
||||
from numpy.random import normal
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.ops.composite import core
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
|
|||
net(t1, t2)
|
||||
|
||||
|
||||
# test free variable function list as parameter
|
||||
def test_remove_and_fv_2():
|
||||
@core(loop_can_uroll=True)
|
||||
def inner_loop(x, input_data, fv_func_list):
|
||||
ret = ()
|
||||
for fv_fn in fv_func_list:
|
||||
ele = fv_fn(input_data)
|
||||
ret += (ele,)
|
||||
return ret
|
||||
|
||||
@ms_function
|
||||
def out_loop(input1, input_data):
|
||||
ret = ()
|
||||
|
||||
def fv_func1(y):
|
||||
return input1 * y
|
||||
def fv_func2(y):
|
||||
return input1 - y
|
||||
fv_func_list = [fv_func1, fv_func2]
|
||||
ele0 = inner_loop(input1, input_data[0], fv_func_list)
|
||||
ele1 = inner_loop(input1, input_data[1], fv_func_list)
|
||||
ret = (ele0, ele1)
|
||||
return ret
|
||||
|
||||
input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
|
||||
input1 = Tensor(normal(0, 0.1, (3, 3)))
|
||||
out_loop(input1, input_data)
|
||||
|
||||
|
||||
# test cell as high order argument
|
||||
# The graph with free variables used as argument is not supported yet
|
||||
# because of the limit of inference specialize system
|
||||
def Xtest_conv2d_op_with_arg():
|
||||
def test_conv2d_op_with_argi_1():
|
||||
class Conv2dNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Conv2dNet, self).__init__()
|
||||
|
@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
|
|||
|
||||
# The partial application used as argument is not supported yet
|
||||
# because of the limit of inference specialize system
|
||||
def Xtest_partial_as_arg():
|
||||
def test_partial_as_arg():
|
||||
class PartialArgNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PartialArgNet, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue