forked from mindspore-Ecosystem/mindspore
add cell class to c++
This commit is contained in:
parent
879a519136
commit
e6f82af849
|
@ -194,9 +194,12 @@ def get_object_key(obj):
|
|||
obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
|
||||
obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
|
||||
else:
|
||||
# `<class 'xxxxxxx'>`
|
||||
# -> `xxxxxxx`
|
||||
tag = str(obj.__class__)[8:-2]
|
||||
if hasattr(obj, "cell_init_args"):
|
||||
obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args)
|
||||
obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj))
|
||||
obj_key = "%s_ID" % (tag + obj.cell_init_args)
|
||||
obj_id = "%s_ID%d" % (tag, id(obj))
|
||||
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
|
||||
|
||||
# method has same id of different instance
|
||||
|
|
|
@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor {
|
|||
}
|
||||
}
|
||||
|
||||
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
|
||||
new_parameters.push_back(new_param);
|
||||
curr_input_idx++;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/cell.h"
|
||||
#include "frontend/parallel/costmodel_context.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "pipeline/jit/pass.h"
|
||||
|
@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) {
|
|||
parse::python_adapter::set_python_env_flag(true);
|
||||
parse::python_adapter::SetPythonPath(dir);
|
||||
|
||||
FuncGraphPtr fg = parse::ConvertToFuncGraph(input);
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parse error.";
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(input, &converted_ret, true);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input));
|
||||
}
|
||||
res->set_func_graph(fg);
|
||||
|
||||
FuncGraphPtr top_graph = nullptr;
|
||||
if (py::isinstance<Cell>(input)) {
|
||||
top_graph = parse::MakeTopGraph(input, converted_ret);
|
||||
} else if (converted_ret->isa<FuncGraph>()) {
|
||||
top_graph = converted_ret->cast<FuncGraphPtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
|
||||
}
|
||||
parse::Parser::UpdateTopFuncGraph(top_graph);
|
||||
|
||||
res->set_func_graph(top_graph);
|
||||
|
||||
FuncGraphManagerPtr manager = res->manager();
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Manager is nullptr.";
|
||||
}
|
||||
manager->AddFuncGraph(fg);
|
||||
manager->AddFuncGraph(top_graph);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/cell.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
|
@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
||||
bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
|
||||
auto obj = py::cast(cell);
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parse resolve function error.";
|
||||
|
@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
|
|||
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
|
||||
// Create the namespace for common class instance
|
||||
// When the obj is Cell, default parse the 'construct'
|
||||
if (data_converter::IsCellInstance(obj)) {
|
||||
return ConvertCellObjToFuncGraph(obj, data);
|
||||
}
|
||||
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
|
@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
ret = ConvertTuple(obj, &converted, use_signature);
|
||||
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
|
||||
ret = ConvertCellList(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<Cell>(obj)) {
|
||||
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
|
||||
} else if (py::isinstance<py::list>(obj)) {
|
||||
ret = ConvertList(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<py::module>(obj)) {
|
||||
|
|
|
@ -140,6 +140,67 @@ void Parser::CleanParserResource() {
|
|||
ScopeManager::GetInstance().ClearScope();
|
||||
}
|
||||
|
||||
AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto value = py::cast<tensor::MetaTensorPtr>(obj);
|
||||
// parameter object should not be none
|
||||
if (value == nullptr || !value->is_parameter()) {
|
||||
MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object.";
|
||||
}
|
||||
|
||||
// get the parameter name from parameter object
|
||||
auto param_name = value->param_info()->name();
|
||||
|
||||
auto top_graph = func_graph;
|
||||
// if the parameter node has been created , return it
|
||||
AnfNodePtr para_node = nullptr;
|
||||
for (auto param : top_graph->parameters()) {
|
||||
auto param_node = dyn_cast<Parameter>(param);
|
||||
if (param_node != nullptr && param_node->name() == param_name) {
|
||||
para_node = param;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (para_node == nullptr) {
|
||||
auto node = top_graph->AddWeightParameter(param_name);
|
||||
|
||||
node->set_default_param(value);
|
||||
// set_abstract for parameter
|
||||
auto abs = value->ToAbstract();
|
||||
// boarden value
|
||||
abs = abs->Broaden();
|
||||
node->set_abstract(abs);
|
||||
para_node = node;
|
||||
}
|
||||
return para_node;
|
||||
}
|
||||
|
||||
void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
|
||||
auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
(void)AppendParameterObj(top_graph, params[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
|
||||
// check whether the functions refered by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(fn, false);
|
||||
for (auto func_graph : mng->func_graphs()) {
|
||||
if (func_graph->get_return() != nullptr) {
|
||||
continue;
|
||||
}
|
||||
py::object node = ast->GetAstNode();
|
||||
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
py::str desc =
|
||||
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
}
|
||||
// clear manager info after checking missing return
|
||||
for (auto fg : mng->func_graphs()) {
|
||||
fg->ClearAllManagerInfo();
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr Parser::ParseFuncGraph() {
|
||||
// get ast FunctionDef node
|
||||
py::object node = ast_->GetAstNode();
|
||||
|
@ -152,22 +213,7 @@ FuncGraphPtr Parser::ParseFuncGraph() {
|
|||
RemoveUnnecessaryPhis();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(pFnBlock);
|
||||
|
||||
// check whether the functions refered by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(pFnBlock->func_graph(), false);
|
||||
for (auto func_graph : mng->func_graphs()) {
|
||||
if (func_graph->get_return() != nullptr) {
|
||||
continue;
|
||||
}
|
||||
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
py::str desc =
|
||||
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
}
|
||||
// clear manager info after checking missing return
|
||||
for (auto fg : mng->func_graphs()) {
|
||||
fg->ClearAllManagerInfo();
|
||||
}
|
||||
CheckFuncReturn(pFnBlock->func_graph(), ast_);
|
||||
|
||||
return pFnBlock->func_graph();
|
||||
}
|
||||
|
@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
|
|||
return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments,
|
||||
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
|
||||
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
|
||||
if (need_unpack) {
|
||||
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments) {
|
||||
std::vector<AnfNodePtr> unpack_call_nodes;
|
||||
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
|
||||
unpack_call_nodes.push_back(unpack_call_op);
|
||||
unpack_call_nodes.push_back(call_function_anf_node);
|
||||
(void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
|
||||
[](AnfNodePtr node) -> AnfNodePtr { return node; });
|
||||
CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes);
|
||||
CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes);
|
||||
return unpack_call;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments,
|
||||
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
|
||||
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
|
||||
if (need_unpack) {
|
||||
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
|
||||
}
|
||||
// else there is no keyword arguments and starred, parsed as normal arguments without unpack
|
||||
std::vector<AnfNodePtr> func_call_nodes;
|
||||
|
@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
func_graph->debug_info()->set_name("top");
|
||||
|
||||
// def top(*arg, *kwargs):
|
||||
auto param_vargs = func_graph->add_parameter();
|
||||
auto args_name = "args";
|
||||
param_vargs->set_name(args_name);
|
||||
param_vargs->debug_info()->set_name(args_name);
|
||||
|
||||
auto param_vkwargs = func_graph->add_parameter();
|
||||
args_name = "kwargs";
|
||||
param_vkwargs->set_name(args_name);
|
||||
param_vkwargs->debug_info()->set_name(args_name);
|
||||
|
||||
func_graph->set_has_vararg(true);
|
||||
func_graph->set_has_kwarg(true);
|
||||
func_graph->set_kwonlyargs_count(0);
|
||||
|
||||
// cell_obj
|
||||
parse::UpdateFuncGraphFlags(cell, func_graph);
|
||||
// top graph's construct flag
|
||||
if (py::hasattr(cell, "construct")) {
|
||||
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
|
||||
}
|
||||
|
||||
UpdataParam(func_graph, cell);
|
||||
|
||||
// ret = cell_obj(*arg, *kwargs)
|
||||
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
|
||||
|
||||
// return ret
|
||||
func_graph->set_output(call_fn);
|
||||
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
|
|||
// Parse the python object to graph
|
||||
FuncGraphPtr ParsePythonCode(const py::object &obj,
|
||||
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
|
||||
// add wrap for cell top graph.
|
||||
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr);
|
||||
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind_api/ir/cell_py.h"
|
||||
#include <string>
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
if (!converted) {
|
||||
MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
} else {
|
||||
MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString();
|
||||
cell->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
}
|
||||
// Define python 'Cell' class.
|
||||
REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
|
||||
(void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_")
|
||||
.def(py::init<std::string &>())
|
||||
.def("__str__", &Cell::ToString)
|
||||
.def("_add_attr", &CellPy::AddAttr, "Add Cell attr.")
|
||||
.def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
|
||||
.def(
|
||||
"construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
|
||||
"construct");
|
||||
}));
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
|
||||
#include "ir/cell.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
// brief mindspore namespace.
|
||||
//
|
||||
// mindspore namespace is the top level namespace of Mindsporeession project.
|
||||
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
|
||||
namespace mindspore {
|
||||
|
||||
// Cell python wrapper and adapter class.
|
||||
class CellPy {
|
||||
public:
|
||||
static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj);
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/cell.h"
|
||||
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
|
||||
abstract::AbstractBasePtr Cell::ToAbstract() {
|
||||
/*
|
||||
std::vector<abstract::AbstractAttribute> abs_attrs;
|
||||
std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs),
|
||||
[](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute {
|
||||
return std::make_pair(attr.first, attr.second->ToAbstract());
|
||||
});
|
||||
auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs);
|
||||
abs->set_value(shared_from_base<Value>());
|
||||
return abs;
|
||||
*/
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool Cell::operator==(const Value &other) const {
|
||||
if (other.isa<Cell>()) {
|
||||
auto other_prim = static_cast<const Cell &>(other);
|
||||
return *this == other_prim;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool Cell::operator==(const Cell &other) const {
|
||||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
}
|
||||
|
||||
std::string Cell::GetAttrString() const {
|
||||
std::ostringstream buffer;
|
||||
bool begin = true;
|
||||
buffer << "{" << std::endl;
|
||||
for (auto &attr : attrs_) {
|
||||
if (!begin) {
|
||||
buffer << ", " << std::endl;
|
||||
} else {
|
||||
begin = false;
|
||||
}
|
||||
buffer << attr.first << ":" << attr.second->ToString();
|
||||
}
|
||||
buffer << "}";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Cell::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Cell " << name();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
void Cell::DelAttr(const std::string &name) { attrs_.erase(name); }
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_CELL_H_
|
||||
#define MINDSPORE_CCSRC_IR_CELL_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/misc.h"
|
||||
|
||||
namespace mindspore {
|
||||
using abstract::AbstractBasePtr;
|
||||
using abstract::AbstractBasePtrList;
|
||||
// value for Cell
|
||||
|
||||
class Cell : public Named {
|
||||
public:
|
||||
explicit Cell(const std::string &name) : Named(name) {}
|
||||
MS_DECLARE_PARENT(Cell, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
std::string ToString() const override;
|
||||
std::string GetAttrString() const;
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs_input) { attrs_ = attrs_input; }
|
||||
|
||||
void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
|
||||
void DelAttr(const std::string &name);
|
||||
ValuePtr GetAttr(const std::string &attr_name) const {
|
||||
auto iter = attrs_.find(attr_name);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
bool HasAttr(const std::string &attr_name) const {
|
||||
auto iter = attrs_.find(attr_name);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
|
||||
bool operator==(const Value &other) const override;
|
||||
bool operator==(const Cell &other) const;
|
||||
~Cell() override = default;
|
||||
const bool parse_info_ = true;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
};
|
||||
|
||||
using CellPtr = std::shared_ptr<Cell>;
|
||||
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_CELL_H_
|
|
@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
|
|||
MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
|
||||
<< " were given.";
|
||||
}
|
||||
auto varg_name = specialized_graph->GetVariableArgName();
|
||||
// for python variable argument input , there is no upper limit
|
||||
for (int i = 0; i < variable_args_count; ++i) {
|
||||
ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
|
||||
std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i);
|
||||
std::string param_name = varg_name + std::to_string(i);
|
||||
p->set_name(param_name);
|
||||
MS_EXCEPTION_IF_NULL(p->debug_info());
|
||||
p->debug_info()->set_name(param_name);
|
||||
|
|
|
@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe
|
|||
while (temp_info != nullptr) {
|
||||
if (temp_info->trace_info() != nullptr) {
|
||||
if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() ||
|
||||
temp_info->trace_info()->isa<TraceGenMetaFuncGraph>()) {
|
||||
temp_info->trace_info()->isa<TraceGenMetaFuncGraph>() ||
|
||||
temp_info->trace_info()->isa<TraceGenerateVarArg>() || temp_info->trace_info()->isa<TraceGenerateKwArg>()) {
|
||||
break;
|
||||
}
|
||||
trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label));
|
||||
|
|
|
@ -24,14 +24,14 @@ from ..common import dtype as mstype
|
|||
from ..common.api import _executor, _pynative_exec
|
||||
from .._checkparam import _check_str_by_regular
|
||||
from ..common.parameter import Parameter, ParameterTuple
|
||||
from .._c_expression import init_backend
|
||||
from .._c_expression import init_backend, Cell_
|
||||
from ..ops.primitive import Primitive
|
||||
from ..ops.operations import HookBackward
|
||||
from ..ops.functional import cast
|
||||
from ..parallel._tensor import _load_tensor_by_layout
|
||||
from ..common.tensor import Tensor
|
||||
|
||||
class Cell:
|
||||
class Cell(Cell_):
|
||||
"""
|
||||
Base class for all neural networks.
|
||||
|
||||
|
@ -58,14 +58,21 @@ class Cell:
|
|||
>>> def construct(self, x):
|
||||
>>> return self.relu(x)
|
||||
"""
|
||||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
|
||||
'_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
|
||||
'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
|
||||
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
Cell_.__init__(self, self._cell_tag)
|
||||
self._params = OrderedDict()
|
||||
self._cells = OrderedDict()
|
||||
self._params_list = OrderedDict()
|
||||
self.training = False
|
||||
self.requires_grad = False
|
||||
self.pynative = False
|
||||
self._attr_synced = False
|
||||
self._param_prefix = ''
|
||||
self._auto_prefix = auto_prefix
|
||||
self._scope = None
|
||||
|
@ -92,6 +99,12 @@ class Cell:
|
|||
def already_run(self):
|
||||
return self._already_run
|
||||
|
||||
@property
|
||||
def _cell_tag(self):
|
||||
# `<class 'xxxxxxx'>`
|
||||
# -> `xxxxxxx`
|
||||
return str(self.__class__)[8:-2]
|
||||
|
||||
@already_run.setter
|
||||
def already_run(self, value):
|
||||
self._already_run = value
|
||||
|
@ -222,6 +235,7 @@ class Cell:
|
|||
del self._cells[name]
|
||||
else:
|
||||
object.__delattr__(self, name)
|
||||
self._attr_synced = False
|
||||
|
||||
def cast_inputs(self, inputs, dst_type):
|
||||
res = list()
|
||||
|
@ -277,6 +291,34 @@ class Cell:
|
|||
self._already_run = True
|
||||
return output
|
||||
|
||||
def _add_attr(self, name, value):
|
||||
if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
|
||||
super(Cell, self)._add_attr(name, value)
|
||||
|
||||
def _sync_attr_for_compile(self):
|
||||
"""Sync the attr to c++ object."""
|
||||
if self._attr_synced:
|
||||
return
|
||||
cells = self.__dict__.get('_cells')
|
||||
for key in cells:
|
||||
cell = cells[key]
|
||||
cell._sync_attr_for_compile()
|
||||
self._add_attr(key, cell)
|
||||
params = self.__dict__.get('_params')
|
||||
for key in params:
|
||||
if '.' in key:
|
||||
continue
|
||||
param = params[key]
|
||||
self._add_attr(key, param)
|
||||
params_list = self.__dict__.get('_params_list')
|
||||
for key in params_list:
|
||||
params_list_item = params_list[key]
|
||||
self._add_attr(key, params_list_item)
|
||||
for key in self.__dict__:
|
||||
value = self.__dict__[key]
|
||||
self._add_attr(key, value)
|
||||
self._attr_synced = True
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
cells = self.__dict__.get('_cells')
|
||||
params = self.__dict__.get('_params')
|
||||
|
@ -329,6 +371,8 @@ class Cell:
|
|||
if isinstance(value, Primitive):
|
||||
value.set_prim_instance_name(name)
|
||||
object.__setattr__(self, name, value)
|
||||
if name not in Cell.IGNORE_LIST:
|
||||
self._attr_synced = False
|
||||
|
||||
def extend_repr(self):
|
||||
"""
|
||||
|
@ -451,7 +495,7 @@ class Cell:
|
|||
Object, the result of executing.
|
||||
"""
|
||||
self._auto_parallel_compile_and_run = True
|
||||
_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
|
||||
self.compile(*inputs)
|
||||
|
||||
if self._auto_parallel_mode:
|
||||
if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""container"""
|
||||
from collections import OrderedDict
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from abc import abstractmethod
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['SequentialCell', 'CellList']
|
||||
|
@ -34,7 +34,7 @@ def _valid_cell(cell):
|
|||
raise TypeError('Cell {} is not subclass of Cell'.format(cell))
|
||||
|
||||
|
||||
class _CellListBase(metaclass=ABCMeta):
|
||||
class _CellListBase():
|
||||
"""
|
||||
An interface for base the cell as list.
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_get_parameter_layout():
|
|||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||
expect_dict = {'args0': x_layout, 'w1': weight_layout}
|
||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
|
|||
y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
|
||||
b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
|
||||
sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
|
||||
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
|
||||
expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout}
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue