forked from mindspore-Ecosystem/mindspore
Fix the problem when use Parameter as top func graph arguments.
This commit is contained in:
parent
b7ec631f11
commit
bbfe665977
|
@ -316,8 +316,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
|
|||
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
meta_fg_var_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_fg_var_prepare", IsCNode);
|
||||
MetaUnpackPrepareLib::MetaUnpackPrepareLib() {
|
||||
meta_unpack_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_unpack_prepare", IsCNode);
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -200,11 +200,11 @@ class ResolveIRPassLib {
|
|||
SubstitutionPtr resolver_;
|
||||
};
|
||||
|
||||
class InferenceOptPrepareLib {
|
||||
class MetaUnpackPrepareLib {
|
||||
public:
|
||||
InferenceOptPrepareLib();
|
||||
~InferenceOptPrepareLib() = default;
|
||||
SubstitutionPtr meta_fg_var_prepare_;
|
||||
MetaUnpackPrepareLib();
|
||||
~MetaUnpackPrepareLib() = default;
|
||||
SubstitutionPtr meta_unpack_prepare_;
|
||||
};
|
||||
|
||||
// predicate functions
|
||||
|
|
|
@ -45,25 +45,22 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> nodes;
|
||||
AnfNodePtr unpack_graph_node = nullptr;
|
||||
std::shared_ptr<prim::UnpackGraphPrimitive> unpack_graph;
|
||||
size_t inputs_begin_index;
|
||||
if (is_unpack) {
|
||||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
|
||||
// {unpackcall, {GradOperation, ...}, args...} and other {unpackcall, {meta_fg_opration, ...}, args...}
|
||||
const size_t inputs_begin_index = 2;
|
||||
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr &node) { return node; });
|
||||
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
|
||||
inputs_begin_index = 2;
|
||||
} else {
|
||||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
|
||||
// {{GradOperation, ...}, args...} and other {{meta_fg_opration, ...}, args...}
|
||||
const size_t inputs_begin_index = 1;
|
||||
(void)std::transform(inputs_y.cbegin() + SizeToLong(inputs_begin_index), inputs_y.cend(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr &node) { return node; });
|
||||
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
|
||||
inputs_begin_index = 1;
|
||||
}
|
||||
(void)nodes.emplace_back(NewValueNode(unpack_graph));
|
||||
(void)nodes.emplace_back(func_node);
|
||||
(void)std::transform(inputs_y.cbegin() + SizeToLong(inputs_begin_index), inputs_y.cend(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr &node) { return node; });
|
||||
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
|
||||
return unpack_graph_node;
|
||||
}
|
||||
|
||||
|
@ -90,8 +87,8 @@ bool CheckMetaFgOps(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// {{GradOperation, g, w}, Ys}, {UnPackCall, {GradOperation, g, w}, Ys},
|
||||
// and other {{meta_fg_opration, ...}, ...} or {UnPackCall, {meta_fg_opration, ...}, ...}
|
||||
// {{GradOperation, g, w}, Ys}, {UnpackCall, {GradOperation, g, w}, Ys},
|
||||
// and other {{meta_fg_opration, ...}, ...} or {UnpackCall, {meta_fg_opration, ...}, ...}
|
||||
AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
|
|
|
@ -297,8 +297,8 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
|
|||
#ifdef ENABLE_PROFILE
|
||||
double t2 = GetTime();
|
||||
#endif
|
||||
auto ret = ProgramSpecialize(resource, func_graph, result.context);
|
||||
resource->set_func_graph(ret);
|
||||
auto res = ProgramSpecialize(resource, func_graph, result.context);
|
||||
resource->set_func_graph(res);
|
||||
#ifdef ENABLE_PROFILE
|
||||
double t3 = GetTime();
|
||||
MsProfile::StatTime("renormalize.infer", t2 - t1);
|
||||
|
@ -307,7 +307,7 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
|
|||
|
||||
MS_LOG(DEBUG) << "Renormalize end";
|
||||
|
||||
return ret;
|
||||
return res;
|
||||
}
|
||||
|
||||
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
|
||||
|
@ -610,29 +610,29 @@ bool OrderEnforceAction(const ResourcePtr &resource) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool InferenceOptPrepareAction(const ResourcePtr &resource) {
|
||||
bool MetaUnpackPrepareAction(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
if (resource->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
|
||||
MS_LOG(EXCEPTION) << "MetaUnpackPrepareAction error, manager is null.";
|
||||
}
|
||||
if (resource->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
|
||||
MS_LOG(EXCEPTION) << "MetaUnpackPrepareAction error, graph is null.";
|
||||
}
|
||||
return InferenceOptPreparePass(resource);
|
||||
return MetaUnpackPreparePass(resource);
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
|
||||
FuncGraphPtr func_graph = resource->func_graph();
|
||||
abstract::AbstractBasePtrList args_abs = resource->args_abs();
|
||||
auto arguments = resource->arguments();
|
||||
|
||||
// Parallel checking.
|
||||
auto context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
||||
context->ParallelParameterContextInitShape(func_graph);
|
||||
|
||||
// Suppose that there is not KeywordArgument for the top graph
|
||||
// get the hyper parameter
|
||||
// Handle the Parameter from FV inputs.
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
|
@ -643,10 +643,33 @@ abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
|
|||
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(abs_value, ref_key);
|
||||
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
|
||||
args_abs.push_back(abs_ref);
|
||||
(void)args_abs.emplace_back(abs_ref);
|
||||
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
|
||||
}
|
||||
}
|
||||
// Handle the Parameter from input arguments.
|
||||
auto arg_size = arguments.size();
|
||||
for (size_t i = 0; i < arg_size; ++i) {
|
||||
auto param_value = dyn_cast<tensor::MetaTensor>(arguments[i]);
|
||||
if (param_value == nullptr || !param_value->is_parameter()) {
|
||||
continue;
|
||||
}
|
||||
const auto ¶m = func_graph->parameters()[i];
|
||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
if (param_node->has_default()) {
|
||||
continue;
|
||||
}
|
||||
// The argument is Parameter.
|
||||
MS_LOG(DEBUG) << "Meet an argument of Parameter, value: " << param_value->ToString()
|
||||
<< ", parameter: " << param->ToString() << ", has_default: " << param_node->has_default();
|
||||
param_node->set_default_param(param_value);
|
||||
// Update the i-th arguments' abstract.
|
||||
auto abs_value = param_value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(abs_value, ref_key);
|
||||
args_abs[i] = abs_ref;
|
||||
}
|
||||
return args_abs;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -1504,7 +1527,7 @@ static std::vector<ActionItem> CommonPipeline() {
|
|||
(void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
||||
}
|
||||
|
||||
(void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||
(void)actions.emplace_back(std::make_pair("meta_unpack_prepare", MetaUnpackPrepareAction));
|
||||
// Evaluate type and shape, and specialize.
|
||||
(void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||
// Auto-monad for side-effects handling.
|
||||
|
|
|
@ -584,10 +584,10 @@ OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPa
|
|||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetInferenceOptPreparePhases() {
|
||||
opt::irpass::InferenceOptPrepareLib irpass;
|
||||
auto meta_fg_var_prepare = opt::OptPassConfig({irpass.meta_fg_var_prepare_});
|
||||
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", meta_fg_var_prepare}});
|
||||
OptPassGroupMap GetMetaUnpackPreparePhases() {
|
||||
opt::irpass::MetaUnpackPrepareLib irpass;
|
||||
auto meta_unpack_prepare = opt::OptPassConfig({irpass.meta_unpack_prepare_});
|
||||
opt::OptPassGroupMap prepare_map({{"meta_unpack_prepare", meta_unpack_prepare}});
|
||||
return prepare_map;
|
||||
}
|
||||
|
||||
|
@ -797,11 +797,11 @@ bool ValidatePass(const ResourcePtr &resource) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool InferenceOptPreparePass(const ResourcePtr &resource) {
|
||||
bool MetaUnpackPreparePass(const ResourcePtr &resource) {
|
||||
FuncGraphPtr func_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto prepare_map = GetInferenceOptPreparePhases();
|
||||
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", resource, prepare_map);
|
||||
auto prepare_map = GetMetaUnpackPreparePhases();
|
||||
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("meta_unpack_prepare", resource, prepare_map);
|
||||
(void)infer_opt_prepare->step(func_graph, false);
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ bool ValidatePass(const ResourcePtr &resource);
|
|||
bool GeSpecializedPass(const ResourcePtr &resource);
|
||||
bool ConvertPrepareAdapt(const ResourcePtr &resource);
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &resource);
|
||||
bool InferenceOptPreparePass(const ResourcePtr &resource);
|
||||
bool MetaUnpackPreparePass(const ResourcePtr &resource);
|
||||
void ReclaimOptimizer();
|
||||
bool PynativeOptPass(const ResourcePtr &resource);
|
||||
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &resource);
|
||||
|
|
|
@ -295,9 +295,11 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::tuple &args, bool ena
|
|||
SetValueMutable(converted->ToAbstract());
|
||||
set_mutable = true;
|
||||
}
|
||||
AbstractBasePtr ptr = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
|
||||
args_abs.push_back(ptr);
|
||||
(void)cur_convert_input_.emplace(args[i].ptr(), ptr);
|
||||
AbstractBasePtr abs = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
|
||||
(void)args_abs.emplace_back(abs);
|
||||
// The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph,
|
||||
// so we keep all inputs for subsequent procedure.
|
||||
(void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
|
||||
}
|
||||
|
||||
// If cache matched no need CheckArgsValid
|
||||
|
@ -851,6 +853,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
|
||||
// Get the parameters items and add the value to args_abs.
|
||||
abstract::AbstractBasePtrList args_abs;
|
||||
std::vector<ValuePtr> arguments;
|
||||
std::size_t size = args.size();
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
ValuePtr converted = nullptr;
|
||||
|
@ -858,15 +861,18 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
// So can't use cur_convert_input_ directly.
|
||||
auto iter = cur_convert_input_.find(args[i].ptr());
|
||||
if (iter != cur_convert_input_.end()) {
|
||||
args_abs.push_back(iter->second);
|
||||
(void)arguments.emplace_back(iter->second.first);
|
||||
(void)args_abs.emplace_back(iter->second.second);
|
||||
continue;
|
||||
}
|
||||
bool succ = parse::ConvertData(args[i], &converted);
|
||||
if (!succ) {
|
||||
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
||||
}
|
||||
args_abs.push_back(ArgsToAbstract(converted, enable_tuple_broaden_));
|
||||
(void)arguments.emplace_back(converted);
|
||||
(void)args_abs.emplace_back(ArgsToAbstract(converted, enable_tuple_broaden_));
|
||||
}
|
||||
resource->set_arguments(arguments);
|
||||
resource->set_args_abs(args_abs);
|
||||
executor_info->arg_list_size = size;
|
||||
executor_info->resource = resource;
|
||||
|
@ -897,7 +903,7 @@ std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionI
|
|||
MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
|
||||
std::vector<ActionItem> filtered_actions;
|
||||
for (const auto &item : actions) {
|
||||
filtered_actions.emplace_back(item);
|
||||
(void)filtered_actions.emplace_back(item);
|
||||
if (item.first == "validate") {
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -159,7 +159,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
py::list compile_cache_dep_files_;
|
||||
bool compile_cache_consistent_{true};
|
||||
py::dict weights_;
|
||||
std::map<PyObject *, AbstractBasePtr> cur_convert_input_;
|
||||
std::map<PyObject *, std::pair<ValuePtr, AbstractBasePtr>> cur_convert_input_;
|
||||
};
|
||||
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;
|
||||
|
||||
|
|
|
@ -86,6 +86,9 @@ class Resource : public ResourceBase {
|
|||
const abstract::AbstractBasePtrList &args_abs() const { return args_abs_; }
|
||||
void set_args_abs(const abstract::AbstractBasePtrList &args_abs) { args_abs_ = args_abs; }
|
||||
|
||||
const std::vector<ValuePtr> &arguments() const { return arguments_; }
|
||||
void set_arguments(const std::vector<ValuePtr> &arguments) { arguments_ = arguments; }
|
||||
|
||||
void set_vm_loop(const bool &flag, const int64_t size) {
|
||||
vm_loop_flag_ = flag;
|
||||
loop_size_ = size;
|
||||
|
@ -122,6 +125,9 @@ class Resource : public ResourceBase {
|
|||
abstract::AnalysisEnginePtr engine_;
|
||||
FuncGraphPtr func_graph_;
|
||||
FuncGraphPtr optimize_graph_;
|
||||
// The arguments may contain a Parameter, we need connect it to the Parameter default value of func graph.
|
||||
// We keep all arguments inputs here for subsequent procedure.
|
||||
std::vector<ValuePtr> arguments_;
|
||||
abstract::AbstractBasePtrList args_abs_;
|
||||
// The source obj to compile, usually a `Cell` or `ms_function` decorated function.
|
||||
py::object source_input_;
|
||||
|
|
|
@ -385,3 +385,30 @@ def test_parameter_same_name_between_tuple_or_list():
|
|||
output = net(x)
|
||||
output_expect = Tensor(20, ms.float32)
|
||||
assert output == output_expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_argument_and_fv():
|
||||
"""
|
||||
Feature: Parameter argmument in top func graph.
|
||||
Description: Use Parameter as input argmument.
|
||||
Expectation: Parameter used as argument should equal to used as FV.
|
||||
"""
|
||||
y = Parameter(Tensor([1]))
|
||||
class Demo(Cell):
|
||||
def construct(self, x):
|
||||
ms.ops.Assign()(x, Tensor([0]))
|
||||
ms.ops.Assign()(y, Tensor([0]))
|
||||
return True
|
||||
|
||||
x = Parameter(Tensor([1]))
|
||||
net = Demo()
|
||||
net(x)
|
||||
print(Tensor(x))
|
||||
print(Tensor(y))
|
||||
assert x == y
|
|
@ -137,9 +137,9 @@ def test_two_net():
|
|||
print("res2:", res2)
|
||||
|
||||
|
||||
class OutNet_1(nn.Cell):
|
||||
class OutNet1(nn.Cell):
|
||||
def __init__(self, net1, net2):
|
||||
super(OutNet_1, self).__init__()
|
||||
super(OutNet1, self).__init__()
|
||||
self.param1 = ParameterTuple(net1.get_parameters())
|
||||
self.param2 = ParameterTuple(net2.get_parameters())
|
||||
|
||||
|
@ -160,14 +160,14 @@ def test_inner_out_net_1():
|
|||
with pytest.raises(RuntimeError, match="its name 'name_a' already exists."):
|
||||
net1 = InnerNet()
|
||||
net2 = InnerNet()
|
||||
out_net = OutNet_1(net1, net2)
|
||||
out_net = OutNet1(net1, net2)
|
||||
res = out_net(Tensor([1], ms.float32))
|
||||
print("res:", res)
|
||||
|
||||
|
||||
class OutNet_2(nn.Cell):
|
||||
class OutNet2(nn.Cell):
|
||||
def __init__(self, net1, net2):
|
||||
super(OutNet_2, self).__init__()
|
||||
super(OutNet2, self).__init__()
|
||||
self.cell_list = nn.CellList()
|
||||
self.cell_list.append(net1)
|
||||
self.cell_list.append(net2)
|
||||
|
@ -190,6 +190,6 @@ def test_inner_out_net_2():
|
|||
"""
|
||||
net1 = InnerNet()
|
||||
net2 = InnerNet()
|
||||
out_net = OutNet_2(net1, net2)
|
||||
out_net = OutNet2(net1, net2)
|
||||
res = out_net(Tensor([1], ms.float32))
|
||||
print("res:", res)
|
|
@ -411,8 +411,8 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_UnpackCall_3args) {
|
||||
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
|
||||
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
|
||||
MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
|
||||
FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 3);
|
||||
|
||||
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
|
||||
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||
|
@ -429,7 +429,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
|
|||
|
||||
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
|
||||
AbstractTuplePtr ret =
|
||||
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
|
||||
dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
|
||||
if (ret == nullptr) {
|
||||
FAIL() << "Cast ret to abstract tuple failed.";
|
||||
}
|
||||
|
@ -439,8 +439,8 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_UnpackCall_5args) {
|
||||
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
|
||||
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
|
||||
MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
|
||||
FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 5);
|
||||
|
||||
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
|
||||
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||
|
@ -457,7 +457,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
|
|||
|
||||
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
|
||||
AbstractTuplePtr ret =
|
||||
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
|
||||
dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
|
||||
if (ret == nullptr) {
|
||||
FAIL() << "Cast ret to abstract tuple failed.";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue