From e6f82af849f5cc45e12c423e4db8e1bd62f29b71 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Thu, 3 Sep 2020 15:37:30 +0800 Subject: [PATCH] add cell class to c++ --- mindspore/_extends/parse/parser.py | 7 +- .../optimizer/irpass/incorporate_getitem.h | 1 - mindspore/ccsrc/pipeline/jit/action.cc | 23 ++- .../pipeline/jit/parse/data_converter.cc | 10 +- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 135 ++++++++++++++---- .../ccsrc/pipeline/jit/parse/parse_base.h | 3 + mindspore/ccsrc/pybind_api/ir/cell_py.cc | 50 +++++++ mindspore/ccsrc/pybind_api/ir/cell_py.h | 44 ++++++ mindspore/core/ir/cell.cc | 94 ++++++++++++ mindspore/core/ir/cell.h | 69 +++++++++ mindspore/core/ir/func_graph_extends.cc | 3 +- mindspore/core/utils/label.cc | 3 +- mindspore/nn/cell.py | 50 ++++++- mindspore/nn/layer/container.py | 4 +- .../parallel/test_get_parameter_layout.py | 2 +- .../python/parallel/test_split_grad_sens.py | 2 +- 16 files changed, 454 insertions(+), 46 deletions(-) create mode 100644 mindspore/ccsrc/pybind_api/ir/cell_py.cc create mode 100644 mindspore/ccsrc/pybind_api/ir/cell_py.h create mode 100644 mindspore/core/ir/cell.cc create mode 100644 mindspore/core/ir/cell.h diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 2c3168bc629..df65de8be8b 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -194,9 +194,12 @@ def get_object_key(obj): obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) else: + # `` + # -> `xxxxxxx` + tag = str(obj.__class__)[8:-2] if hasattr(obj, "cell_init_args"): - obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args) - obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj)) + obj_key = "%s_ID" % (tag + obj.cell_init_args) + obj_id = "%s_ID%d" % (tag, id(obj)) logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) # method has same id of different instance diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 9bfa3669406..998eb0c8e8b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor { } } - // (void)mng->Replace(new_fg_parameters[param_i], new_param); new_parameters.push_back(new_param); curr_input_idx++; } diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index b2434d5a1c1..37ed3f0ef23 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -25,6 +25,7 @@ #include "ir/func_graph_cloner.h" #include "ir/param_info.h" +#include "ir/cell.h" #include "frontend/parallel/costmodel_context.h" #include "frontend/parallel/context.h" #include "pipeline/jit/pass.h" @@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) { parse::python_adapter::set_python_env_flag(true); parse::python_adapter::SetPythonPath(dir); - FuncGraphPtr fg = parse::ConvertToFuncGraph(input); - if (fg == nullptr) { - MS_LOG(EXCEPTION) << "Parse error."; + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(input, &converted_ret, true); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input)); } - res->set_func_graph(fg); + + FuncGraphPtr top_graph = nullptr; + if (py::isinstance(input)) { + top_graph = parse::MakeTopGraph(input, converted_ret); + } else if (converted_ret->isa()) { + top_graph = converted_ret->cast(); + } else { + MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell."; + } + parse::Parser::UpdateTopFuncGraph(top_graph); + + res->set_func_graph(top_graph); FuncGraphManagerPtr manager = res->manager(); if (manager == nullptr) { MS_LOG(EXCEPTION) << "Manager is nullptr."; } - manager->AddFuncGraph(fg); + manager->AddFuncGraph(top_graph); return true; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 2f43c7c0cdc..141c636ec73 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -27,6 +27,7 @@ #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "ir/func_graph_cloner.h" +#include "ir/cell.h" #include "utils/symbolic.h" #include "utils/ms_context.h" @@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) { return true; } -bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { +bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { + auto obj = py::cast(cell); FuncGraphPtr func_graph = ConvertToFuncGraph(obj); if (func_graph == nullptr) { MS_LOG(ERROR) << "Parse resolve function error."; @@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { // Create the namespace for common class instance // When the obj is Cell, default parse the 'construct' - if (data_converter::IsCellInstance(obj)) { - return ConvertCellObjToFuncGraph(obj, data); - } - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); @@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature ret = ConvertTuple(obj, &converted, use_signature); } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { ret = ConvertCellList(obj, &converted, use_signature); + } else if (py::isinstance(obj)) { + return ConvertCellObjToFuncGraph(obj.cast(), data); } else if (py::isinstance(obj)) { ret = ConvertList(obj, &converted, use_signature); } else if (py::isinstance(obj)) { diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 9e7ed723b58..e7d5417e6f0 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -140,6 +140,67 @@ void Parser::CleanParserResource() { ScopeManager::GetInstance().ClearScope(); } +AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { + MS_EXCEPTION_IF_NULL(func_graph); + auto value = py::cast(obj); + // parameter object should not be none + if (value == nullptr || !value->is_parameter()) { + MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object."; + } + + // get the parameter name from parameter object + auto param_name = value->param_info()->name(); + + auto top_graph = func_graph; + // if the parameter node has been created , return it + AnfNodePtr para_node = nullptr; + for (auto param : top_graph->parameters()) { + auto param_node = dyn_cast(param); + if (param_node != nullptr && param_node->name() == param_name) { + para_node = param; + break; + } + } + if (para_node == nullptr) { + auto node = top_graph->AddWeightParameter(param_name); + + node->set_default_param(value); + // set_abstract for parameter + auto abs = value->ToAbstract(); + // boarden value + abs = abs->Broaden(); + node->set_abstract(abs); + para_node = node; + } + return para_node; +} + +void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) { + auto params = py::list(cell.attr("get_parameters")()).cast>(); + for (size_t i = 0; i < params.size(); i++) { + (void)AppendParameterObj(top_graph, params[i]); + } +} + +void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr &ast) { + // check whether the functions refered by this function and itself are missing 'return' statement + auto mng = Manage(fn, false); + for (auto func_graph : mng->func_graphs()) { + if (func_graph->get_return() != nullptr) { + continue; + } + py::object node = ast->GetAstNode(); + py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + py::str desc = + python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); + MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast() << "."; + } + // clear manager info after checking missing return + for (auto fg : mng->func_graphs()) { + fg->ClearAllManagerInfo(); + } +} + FuncGraphPtr Parser::ParseFuncGraph() { // get ast FunctionDef node py::object node = ast_->GetAstNode(); @@ -152,22 +213,7 @@ FuncGraphPtr Parser::ParseFuncGraph() { RemoveUnnecessaryPhis(); MS_EXCEPTION_IF_NULL(pFnBlock); - - // check whether the functions refered by this function and itself are missing 'return' statement - auto mng = Manage(pFnBlock->func_graph(), false); - for (auto func_graph : mng->func_graphs()) { - if (func_graph->get_return() != nullptr) { - continue; - } - py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - py::str desc = - python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); - MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast() << "."; - } - // clear manager info after checking missing return - for (auto fg : mng->func_graphs()) { - fg->ClearAllManagerInfo(); - } + CheckFuncReturn(pFnBlock->func_graph(), ast_); return pFnBlock->func_graph(); } @@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); } +CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node, + const std::vector &packed_arguments) { + std::vector unpack_call_nodes; + auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); + unpack_call_nodes.push_back(unpack_call_op); + unpack_call_nodes.push_back(call_function_anf_node); + (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), + [](AnfNodePtr node) -> AnfNodePtr { return node; }); + CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes); + return unpack_call; +} + AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, const std::vector &packed_arguments, const std::vector &group_arguments, bool need_unpack) const { // if there is keyword arguments or starred, using an unpack_call op to unpack the argument if (need_unpack) { - std::vector unpack_call_nodes; - auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); - unpack_call_nodes.push_back(unpack_call_op); - unpack_call_nodes.push_back(call_function_anf_node); - (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), - [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); - return unpack_call; + return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments); } // else there is no keyword arguments and starred, parsed as normal arguments without unpack std::vector func_call_nodes; @@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { return true; } +FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { + auto func_graph = std::make_shared(); + func_graph->debug_info()->set_name("top"); + + // def top(*arg, *kwargs): + auto param_vargs = func_graph->add_parameter(); + auto args_name = "args"; + param_vargs->set_name(args_name); + param_vargs->debug_info()->set_name(args_name); + + auto param_vkwargs = func_graph->add_parameter(); + args_name = "kwargs"; + param_vkwargs->set_name(args_name); + param_vkwargs->debug_info()->set_name(args_name); + + func_graph->set_has_vararg(true); + func_graph->set_has_kwarg(true); + func_graph->set_kwonlyargs_count(0); + + // cell_obj + parse::UpdateFuncGraphFlags(cell, func_graph); + // top graph's construct flag + if (py::hasattr(cell, "construct")) { + parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph); + } + + UpdataParam(func_graph, cell); + + // ret = cell_obj(*arg, *kwargs) + auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs}); + + // return ret + func_graph->set_output(call_fn); + MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); + return func_graph; +} } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 25818214138..8f73c8aca3a 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, // Parse the python object to graph FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +// add wrap for cell top graph. +FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr); + } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/ir/cell_py.cc b/mindspore/ccsrc/pybind_api/ir/cell_py.cc new file mode 100644 index 00000000000..efd80209a26 --- /dev/null +++ b/mindspore/ccsrc/pybind_api/ir/cell_py.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind_api/ir/cell_py.h" +#include + +#include "pybind_api/api_register.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) { + std::string attr_name = name; + ValuePtr converted_ret = nullptr; + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module"; + } + bool converted = parse::ConvertData(obj, &converted_ret, true); + if (!converted) { + MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj)); + } else { + MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString(); + cell->AddAttr(attr_name, converted_ret); + } +} +// Define python 'Cell' class. +REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { + (void)py::class_>(*m, "Cell_") + .def(py::init()) + .def("__str__", &Cell::ToString) + .def("_add_attr", &CellPy::AddAttr, "Add Cell attr.") + .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") + .def( + "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, + "construct"); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/ir/cell_py.h b/mindspore/ccsrc/pybind_api/ir/cell_py.h new file mode 100644 index 00000000000..7011fdafb82 --- /dev/null +++ b/mindspore/ccsrc/pybind_api/ir/cell_py.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_ +#define MINDSPORE_CCSRC_UTILS_CELL_PY_H_ + +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include "ir/cell.h" + +namespace py = pybind11; +// brief mindspore namespace. +// +// mindspore namespace is the top level namespace of Mindsporeession project. +// Other namespace should be a sub namespace of mindspore namespace in the ME project. +namespace mindspore { + +// Cell python wrapper and adapter class. +class CellPy { + public: + static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj); +}; + +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_ diff --git a/mindspore/core/ir/cell.cc b/mindspore/core/ir/cell.cc new file mode 100644 index 00000000000..730b8d92b27 --- /dev/null +++ b/mindspore/core/ir/cell.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/cell.h" + +#include +#include +#include + +#include "abstract/abstract_value.h" + +namespace mindspore { +using mindspore::abstract::AbstractFunction; + +abstract::AbstractBasePtr Cell::ToAbstract() { + /* + std::vector abs_attrs; + std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs), + [](std::pair attr) -> abstract::AbstractAttribute { + return std::make_pair(attr.first, attr.second->ToAbstract()); + }); + auto abs = std::make_shared(shared_from_base(), abs_attrs); + abs->set_value(shared_from_base()); + return abs; + */ + return nullptr; +} + +bool Cell::operator==(const Value &other) const { + if (other.isa()) { + auto other_prim = static_cast(other); + return *this == other_prim; + } else { + return false; + } +} + +bool Cell::operator==(const Cell &other) const { + if (name() != other.name()) { + return false; + } + if (attrs_.size() != other.attrs_.size()) { + return false; + } + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { + if (item.second == nullptr) { + return false; + } + auto iter = other.attrs_.find(item.first); + if (iter == other.attrs_.end()) { + return false; + } + return *item.second == *iter->second; + }); + return all; +} + +std::string Cell::GetAttrString() const { + std::ostringstream buffer; + bool begin = true; + buffer << "{" << std::endl; + for (auto &attr : attrs_) { + if (!begin) { + buffer << ", " << std::endl; + } else { + begin = false; + } + buffer << attr.first << ":" << attr.second->ToString(); + } + buffer << "}"; + return buffer.str(); +} + +std::string Cell::ToString() const { + std::ostringstream buffer; + buffer << "Cell " << name(); + return buffer.str(); +} + +void Cell::DelAttr(const std::string &name) { attrs_.erase(name); } +} // namespace mindspore diff --git a/mindspore/core/ir/cell.h b/mindspore/core/ir/cell.h new file mode 100644 index 00000000000..1622c5a81d0 --- /dev/null +++ b/mindspore/core/ir/cell.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_CELL_H_ +#define MINDSPORE_CCSRC_IR_CELL_H_ + +#include +#include +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "utils/misc.h" + +namespace mindspore { +using abstract::AbstractBasePtr; +using abstract::AbstractBasePtrList; +// value for Cell + +class Cell : public Named { + public: + explicit Cell(const std::string &name) : Named(name) {} + MS_DECLARE_PARENT(Cell, Named); + abstract::AbstractBasePtr ToAbstract() override; + std::string ToString() const override; + std::string GetAttrString() const; + + const std::unordered_map &attrs() const { return attrs_; } + void set_attrs(const std::unordered_map &attrs_input) { attrs_ = attrs_input; } + + void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; } + void DelAttr(const std::string &name); + ValuePtr GetAttr(const std::string &attr_name) const { + auto iter = attrs_.find(attr_name); + return iter == attrs_.cend() ? nullptr : iter->second; + } + + bool HasAttr(const std::string &attr_name) const { + auto iter = attrs_.find(attr_name); + return !(iter == attrs_.cend()); + } + + bool operator==(const Value &other) const override; + bool operator==(const Cell &other) const; + ~Cell() override = default; + const bool parse_info_ = true; + + private: + std::unordered_map attrs_; +}; + +using CellPtr = std::shared_ptr; + +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_CELL_H_ diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 0faaa63f1cf..2416da08238 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count << " were given."; } + auto varg_name = specialized_graph->GetVariableArgName(); // for python variable argument input , there is no upper limit for (int i = 0; i < variable_args_count; ++i) { ParameterPtr p = std::make_shared(specialized_graph); - std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); + std::string param_name = varg_name + std::to_string(i); p->set_name(param_name); MS_EXCEPTION_IF_NULL(p->debug_info()); p->debug_info()->set_name(param_name); diff --git a/mindspore/core/utils/label.cc b/mindspore/core/utils/label.cc index ef4ce9ee3cb..f444d3fa186 100644 --- a/mindspore/core/utils/label.cc +++ b/mindspore/core/utils/label.cc @@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe while (temp_info != nullptr) { if (temp_info->trace_info() != nullptr) { if (temp_info->trace_info()->isa() || temp_info->trace_info()->isa() || - temp_info->trace_info()->isa()) { + temp_info->trace_info()->isa() || + temp_info->trace_info()->isa() || temp_info->trace_info()->isa()) { break; } trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e50e93e8143..0baca78d54a 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -24,14 +24,14 @@ from ..common import dtype as mstype from ..common.api import _executor, _pynative_exec from .._checkparam import _check_str_by_regular from ..common.parameter import Parameter, ParameterTuple -from .._c_expression import init_backend +from .._c_expression import init_backend, Cell_ from ..ops.primitive import Primitive from ..ops.operations import HookBackward from ..ops.functional import cast from ..parallel._tensor import _load_tensor_by_layout from ..common.tensor import Tensor -class Cell: +class Cell(Cell_): """ Base class for all neural networks. @@ -58,14 +58,21 @@ class Cell: >>> def construct(self, x): >>> return self.relu(x) """ + IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', + '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', + '_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode', + '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced', + 'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type'] def __init__(self, auto_prefix=True, flags=None): + Cell_.__init__(self, self._cell_tag) self._params = OrderedDict() self._cells = OrderedDict() self._params_list = OrderedDict() self.training = False self.requires_grad = False self.pynative = False + self._attr_synced = False self._param_prefix = '' self._auto_prefix = auto_prefix self._scope = None @@ -92,6 +99,12 @@ class Cell: def already_run(self): return self._already_run + @property + def _cell_tag(self): + # `` + # -> `xxxxxxx` + return str(self.__class__)[8:-2] + @already_run.setter def already_run(self, value): self._already_run = value @@ -222,6 +235,7 @@ class Cell: del self._cells[name] else: object.__delattr__(self, name) + self._attr_synced = False def cast_inputs(self, inputs, dst_type): res = list() @@ -277,6 +291,34 @@ class Cell: self._already_run = True return output + def _add_attr(self, name, value): + if name and name[:2] != '__' and name not in Cell.IGNORE_LIST: + super(Cell, self)._add_attr(name, value) + + def _sync_attr_for_compile(self): + """Sync the attr to c++ object.""" + if self._attr_synced: + return + cells = self.__dict__.get('_cells') + for key in cells: + cell = cells[key] + cell._sync_attr_for_compile() + self._add_attr(key, cell) + params = self.__dict__.get('_params') + for key in params: + if '.' in key: + continue + param = params[key] + self._add_attr(key, param) + params_list = self.__dict__.get('_params_list') + for key in params_list: + params_list_item = params_list[key] + self._add_attr(key, params_list_item) + for key in self.__dict__: + value = self.__dict__[key] + self._add_attr(key, value) + self._attr_synced = True + def __setattr__(self, name, value): cells = self.__dict__.get('_cells') params = self.__dict__.get('_params') @@ -329,6 +371,8 @@ class Cell: if isinstance(value, Primitive): value.set_prim_instance_name(name) object.__setattr__(self, name, value) + if name not in Cell.IGNORE_LIST: + self._attr_synced = False def extend_repr(self): """ @@ -451,7 +495,7 @@ class Cell: Object, the result of executing. """ self._auto_parallel_compile_and_run = True - _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) + self.compile(*inputs) if self._auto_parallel_mode: if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py index 843a0784d54..62a50d12550 100644 --- a/mindspore/nn/layer/container.py +++ b/mindspore/nn/layer/container.py @@ -14,7 +14,7 @@ # ============================================================================ """container""" from collections import OrderedDict -from abc import abstractmethod, ABCMeta +from abc import abstractmethod from ..cell import Cell __all__ = ['SequentialCell', 'CellList'] @@ -34,7 +34,7 @@ def _valid_cell(cell): raise TypeError('Cell {} is not subclass of Cell'.format(cell)) -class _CellListBase(metaclass=ABCMeta): +class _CellListBase(): """ An interface for base the cell as list. diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py index 01248d13aee..48b40e53841 100644 --- a/tests/ut/python/parallel/test_get_parameter_layout.py +++ b/tests/ut/python/parallel/test_get_parameter_layout.py @@ -51,7 +51,7 @@ def test_get_parameter_layout(): exe.compile(net, x, phase='train', auto_parallel_mode=True) x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] - expect_dict = {'x': x_layout, 'w1': weight_layout} + expect_dict = {'args0': x_layout, 'w1': weight_layout} # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut assert net.parameter_layout_dict == expect_dict diff --git a/tests/ut/python/parallel/test_split_grad_sens.py b/tests/ut/python/parallel/test_split_grad_sens.py index 643832aa77a..c77260faa78 100644 --- a/tests/ut/python/parallel/test_split_grad_sens.py +++ b/tests/ut/python/parallel/test_split_grad_sens.py @@ -125,7 +125,7 @@ def test_grad_sens_parameter_type(): y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] - expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout} + expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout} assert net.parameter_layout_dict == expect_dict