Fix the problem when using Parameter as a top func graph argument.

Zhang Qinghua 2022-07-27 20:12:55 +08:00
parent b7ec631f11
commit bbfe665977
14 changed files with 119 additions and 60 deletions
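
In short: a Parameter passed directly as an input to the top func graph previously did not behave like the same Parameter captured as a free variable (FV), so assignments to the argument were not reflected on the caller's side. A minimal sketch of the scenario, condensed from the new test added at the end of this commit (the class name Net is hypothetical; Parameter, Tensor, Cell and ops.Assign are the APIs the test itself uses):

from mindspore import Parameter, Tensor, ops
from mindspore.nn import Cell

y = Parameter(Tensor([1]))            # captured as a free variable (FV)

class Net(Cell):                      # hypothetical name
    def construct(self, x):           # x is a Parameter passed as an argument
        ops.Assign()(x, Tensor([0]))  # must write through to the caller's x
        ops.Assign()(y, Tensor([0]))
        return True

x = Parameter(Tensor([1]))
Net()(x)
assert x == y                         # both assignments take effect with this fix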

View File

@@ -316,8 +316,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {
meta_fg_var_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_fg_var_prepare", IsCNode);
MetaUnpackPrepareLib::MetaUnpackPrepareLib() {
meta_unpack_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_unpack_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt

View File

@@ -200,11 +200,11 @@ class ResolveIRPassLib {
SubstitutionPtr resolver_;
};
class InferenceOptPrepareLib {
class MetaUnpackPrepareLib {
public:
InferenceOptPrepareLib();
~InferenceOptPrepareLib() = default;
SubstitutionPtr meta_fg_var_prepare_;
MetaUnpackPrepareLib();
~MetaUnpackPrepareLib() = default;
SubstitutionPtr meta_unpack_prepare_;
};
// predicate functions

View File

@@ -45,25 +45,22 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> nodes;
AnfNodePtr unpack_graph_node = nullptr;
std::shared_ptr<prim::UnpackGraphPrimitive> unpack_graph;
size_t inputs_begin_index;
if (is_unpack) {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
// {unpackcall, {GradOperation, ...}, args...} and other {unpackcall, {meta_fg_operation, ...}, args...}
const size_t inputs_begin_index = 2;
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
inputs_begin_index = 2;
} else {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
// {{GradOperation, ...}, args...} and other {{meta_fg_operation, ...}, args...}
const size_t inputs_begin_index = 1;
(void)std::transform(inputs_y.cbegin() + SizeToLong(inputs_begin_index), inputs_y.cend(), std::back_inserter(nodes),
[](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
inputs_begin_index = 1;
}
(void)nodes.emplace_back(NewValueNode(unpack_graph));
(void)nodes.emplace_back(func_node);
(void)std::transform(inputs_y.cbegin() + SizeToLong(inputs_begin_index), inputs_y.cend(), std::back_inserter(nodes),
[](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
return unpack_graph_node;
}
@@ -90,8 +87,8 @@ bool CheckMetaFgOps(const AnfNodePtr &node) {
return false;
}
// {{GradOperation, g, w}, Ys}, {UnPackCall, {GradOperation, g, w}, Ys},
// and other {{meta_fg_operation, ...}, ...} or {UnPackCall, {meta_fg_operation, ...}, ...}
// {{GradOperation, g, w}, Ys}, {UnpackCall, {GradOperation, g, w}, Ys},
// and other {{meta_fg_operation, ...}, ...} or {UnpackCall, {meta_fg_operation, ...}, ...}
AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
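
For context, the call shapes named in the comments above come from patterns like the following sketch, assuming MindSpore's ops.GradOperation (the MulNet/GradNet names are hypothetical). Unpacking *inputs inside construct yields the {UnpackCall, {GradOperation, net}, args...} form; a direct call yields {{GradOperation, net}, args...}:

from mindspore import Tensor, ops
from mindspore.nn import Cell

class MulNet(Cell):                   # hypothetical name
    def construct(self, a, b):
        return a * b

class GradNet(Cell):                  # hypothetical wrapper
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True)

    def construct(self, *inputs):
        # The *inputs unpacking here compiles to the UnpackCall form.
        return self.grad(self.net)(*inputs)

grads = GradNet(MulNet())(Tensor([2.0]), Tensor([3.0]))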

View File

@@ -297,8 +297,8 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
#ifdef ENABLE_PROFILE
double t2 = GetTime();
#endif
auto ret = ProgramSpecialize(resource, func_graph, result.context);
resource->set_func_graph(ret);
auto res = ProgramSpecialize(resource, func_graph, result.context);
resource->set_func_graph(res);
#ifdef ENABLE_PROFILE
double t3 = GetTime();
MsProfile::StatTime("renormalize.infer", t2 - t1);
@@ -307,7 +307,7 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
MS_LOG(DEBUG) << "Renormalize end";
return ret;
return res;
}
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
@@ -610,29 +610,29 @@ bool OrderEnforceAction(const ResourcePtr &resource) {
return true;
}
bool InferenceOptPrepareAction(const ResourcePtr &resource) {
bool MetaUnpackPrepareAction(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
if (resource->manager() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
MS_LOG(EXCEPTION) << "MetaUnpackPrepareAction error, manager is null.";
}
if (resource->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
MS_LOG(EXCEPTION) << "MetaUnpackPrepareAction error, graph is null.";
}
return InferenceOptPreparePass(resource);
return MetaUnpackPreparePass(resource);
}
namespace {
abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
FuncGraphPtr func_graph = resource->func_graph();
abstract::AbstractBasePtrList args_abs = resource->args_abs();
auto arguments = resource->arguments();
// Parallel checking.
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);
// Suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
// Handle the Parameter from FV inputs.
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
MS_EXCEPTION_IF_NULL(param_node);
@@ -643,10 +643,33 @@ abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(abs_value, ref_key);
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
args_abs.push_back(abs_ref);
(void)args_abs.emplace_back(abs_ref);
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
}
}
// Handle the Parameter from input arguments.
auto arg_size = arguments.size();
for (size_t i = 0; i < arg_size; ++i) {
auto param_value = dyn_cast<tensor::MetaTensor>(arguments[i]);
if (param_value == nullptr || !param_value->is_parameter()) {
continue;
}
const auto &param = func_graph->parameters()[i];
auto param_node = std::static_pointer_cast<Parameter>(param);
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->has_default()) {
continue;
}
// The argument is a Parameter.
MS_LOG(DEBUG) << "Found a Parameter argument, value: " << param_value->ToString()
<< ", parameter: " << param->ToString() << ", has_default: " << param_node->has_default();
param_node->set_default_param(param_value);
// Update the i-th argument's abstract.
auto abs_value = param_value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(abs_value, ref_key);
args_abs[i] = abs_ref;
}
return args_abs;
}
} // namespace
@@ -1504,7 +1527,7 @@ static std::vector<ActionItem> CommonPipeline() {
(void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
(void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
(void)actions.emplace_back(std::make_pair("meta_unpack_prepare", MetaUnpackPrepareAction));
// Evaluate type and shape, and specialize.
(void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
// Auto-monad for side-effects handling.
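
At the Python level, the new loop in GetArgsAbs means a Parameter passed positionally is attached as the default of the matching graph parameter and its abstract becomes an AbstractRefTensor, so in-place ops such as Assign write through to the caller's Parameter. A minimal sketch under those assumptions (WriteThrough is a hypothetical name):

import mindspore as ms
from mindspore import Parameter, Tensor, ops
from mindspore.nn import Cell

class WriteThrough(Cell):             # hypothetical name
    def construct(self, p):
        # p is bound to the graph Parameter's default, i.e. a RefTensor,
        # so this assignment is visible outside the graph.
        ops.Assign()(p, Tensor([42], ms.float32))
        return p

weight = Parameter(Tensor([0], ms.float32), name="weight")
WriteThrough()(weight)
assert weight == Tensor([42], ms.float32)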

View File

@@ -584,10 +584,10 @@ OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPa
return map;
}
OptPassGroupMap GetInferenceOptPreparePhases() {
opt::irpass::InferenceOptPrepareLib irpass;
auto meta_fg_var_prepare = opt::OptPassConfig({irpass.meta_fg_var_prepare_});
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", meta_fg_var_prepare}});
OptPassGroupMap GetMetaUnpackPreparePhases() {
opt::irpass::MetaUnpackPrepareLib irpass;
auto meta_unpack_prepare = opt::OptPassConfig({irpass.meta_unpack_prepare_});
opt::OptPassGroupMap prepare_map({{"meta_unpack_prepare", meta_unpack_prepare}});
return prepare_map;
}
@@ -797,11 +797,11 @@ bool ValidatePass(const ResourcePtr &resource) {
return true;
}
bool InferenceOptPreparePass(const ResourcePtr &resource) {
bool MetaUnpackPreparePass(const ResourcePtr &resource) {
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto prepare_map = GetInferenceOptPreparePhases();
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", resource, prepare_map);
auto prepare_map = GetMetaUnpackPreparePhases();
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("meta_unpack_prepare", resource, prepare_map);
(void)infer_opt_prepare->step(func_graph, false);
return true;
}

View File

@@ -46,7 +46,7 @@ bool ValidatePass(const ResourcePtr &resource);
bool GeSpecializedPass(const ResourcePtr &resource);
bool ConvertPrepareAdapt(const ResourcePtr &resource);
bool AddCacheEmbeddingPass(const ResourcePtr &resource);
bool InferenceOptPreparePass(const ResourcePtr &resource);
bool MetaUnpackPreparePass(const ResourcePtr &resource);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &resource);
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &resource);

View File

@@ -295,9 +295,11 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::tuple &args, bool ena
SetValueMutable(converted->ToAbstract());
set_mutable = true;
}
AbstractBasePtr ptr = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
args_abs.push_back(ptr);
(void)cur_convert_input_.emplace(args[i].ptr(), ptr);
AbstractBasePtr abs = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
(void)args_abs.emplace_back(abs);
// The 'converted' may be a Parameter; we need to connect it to the Parameter of the func graph,
// so we keep all inputs for the subsequent procedure.
(void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
}
// If cache matched no need CheckArgsValid
@@ -851,6 +853,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
// Get the parameters items and add the value to args_abs.
abstract::AbstractBasePtrList args_abs;
std::vector<ValuePtr> arguments;
std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) {
ValuePtr converted = nullptr;
@@ -858,15 +861,18 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
// So can't use cur_convert_input_ directly.
auto iter = cur_convert_input_.find(args[i].ptr());
if (iter != cur_convert_input_.end()) {
args_abs.push_back(iter->second);
(void)arguments.emplace_back(iter->second.first);
(void)args_abs.emplace_back(iter->second.second);
continue;
}
bool succ = parse::ConvertData(args[i], &converted);
if (!succ) {
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
}
args_abs.push_back(ArgsToAbstract(converted, enable_tuple_broaden_));
(void)arguments.emplace_back(converted);
(void)args_abs.emplace_back(ArgsToAbstract(converted, enable_tuple_broaden_));
}
resource->set_arguments(arguments);
resource->set_args_abs(args_abs);
executor_info->arg_list_size = size;
executor_info->resource = resource;
@@ -897,7 +903,7 @@ std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionI
MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
std::vector<ActionItem> filtered_actions;
for (const auto &item : actions) {
filtered_actions.emplace_back(item);
(void)filtered_actions.emplace_back(item);
if (item.first == "validate") {
break;
}

View File

@@ -159,7 +159,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
py::list compile_cache_dep_files_;
bool compile_cache_consistent_{true};
py::dict weights_;
std::map<PyObject *, AbstractBasePtr> cur_convert_input_;
std::map<PyObject *, std::pair<ValuePtr, AbstractBasePtr>> cur_convert_input_;
};
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;

View File

@@ -86,6 +86,9 @@ class Resource : public ResourceBase {
const abstract::AbstractBasePtrList &args_abs() const { return args_abs_; }
void set_args_abs(const abstract::AbstractBasePtrList &args_abs) { args_abs_ = args_abs; }
const std::vector<ValuePtr> &arguments() const { return arguments_; }
void set_arguments(const std::vector<ValuePtr> &arguments) { arguments_ = arguments; }
void set_vm_loop(const bool &flag, const int64_t size) {
vm_loop_flag_ = flag;
loop_size_ = size;
@@ -122,6 +125,9 @@ class Resource : public ResourceBase {
abstract::AnalysisEnginePtr engine_;
FuncGraphPtr func_graph_;
FuncGraphPtr optimize_graph_;
// The arguments may contain a Parameter; we need to connect it to the Parameter default value of the func graph.
// We keep all argument inputs here for the subsequent procedure.
std::vector<ValuePtr> arguments_;
abstract::AbstractBasePtrList args_abs_;
// The source obj to compile, usually a `Cell` or `ms_function` decorated function.
py::object source_input_;

View File

@@ -385,3 +385,30 @@ def test_parameter_same_name_between_tuple_or_list():
output = net(x)
output_expect = Tensor(20, ms.float32)
assert output == output_expect

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_argument_and_fv():
"""
Feature: Parameter argument in top func graph.
Description: Use Parameter as an input argument.
Expectation: Parameter used as an argument should behave the same as one used as an FV.
"""
y = Parameter(Tensor([1]))
class Demo(Cell):
def construct(self, x):
ms.ops.Assign()(x, Tensor([0]))
ms.ops.Assign()(y, Tensor([0]))
return True
x = Parameter(Tensor([1]))
net = Demo()
net(x)
print(Tensor(x))
print(Tensor(y))
assert x == y

View File

@@ -137,9 +137,9 @@ def test_two_net():
print("res2:", res2)
class OutNet_1(nn.Cell):
class OutNet1(nn.Cell):
def __init__(self, net1, net2):
super(OutNet_1, self).__init__()
super(OutNet1, self).__init__()
self.param1 = ParameterTuple(net1.get_parameters())
self.param2 = ParameterTuple(net2.get_parameters())
@@ -160,14 +160,14 @@ def test_inner_out_net_1():
with pytest.raises(RuntimeError, match="its name 'name_a' already exists."):
net1 = InnerNet()
net2 = InnerNet()
out_net = OutNet_1(net1, net2)
out_net = OutNet1(net1, net2)
res = out_net(Tensor([1], ms.float32))
print("res:", res)
class OutNet_2(nn.Cell):
class OutNet2(nn.Cell):
def __init__(self, net1, net2):
super(OutNet_2, self).__init__()
super(OutNet2, self).__init__()
self.cell_list = nn.CellList()
self.cell_list.append(net1)
self.cell_list.append(net2)
@@ -190,6 +190,6 @@ def test_inner_out_net_2():
"""
net1 = InnerNet()
net2 = InnerNet()
out_net = OutNet_2(net1, net2)
out_net = OutNet2(net1, net2)
res = out_net(Tensor([1], ms.float32))
print("res:", res)

View File

@@ -411,8 +411,8 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
}
TEST_F(TestComposite, test_UnpackCall_3args) {
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 3);
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
@@ -429,7 +429,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -439,8 +439,8 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
}
TEST_F(TestComposite, test_UnpackCall_5args) {
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 5);
MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 5);
auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
@@ -457,7 +457,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}