!21741 clean code for master

Merge pull request !21741 from changzherui/clean_code_ma
This commit is contained in:
i-robot 2021-08-19 02:21:35 +00:00 committed by Gitee
commit 115da9f797
16 changed files with 88 additions and 40 deletions

View File

@ -148,7 +148,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
Check argument integer. Check argument integer.
Example: Example:
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 - number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
""" """
rel_fn = Rel.get_fns(rel) rel_fn = Rel.get_fns(rel)
prim_name = f'in `{prim_name}`' if prim_name else '' prim_name = f'in `{prim_name}`' if prim_name else ''

View File

@ -178,6 +178,7 @@ void IrExportBuilder::BuildModelInfo() {
} }
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) { void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
mind_ir::GraphProto *graph_proto = model_.mutable_graph(); mind_ir::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString()); graph_proto->set_name(func_graph->ToString());
graph_proto->set_bprop_hash(func_graph->bprop_hash()); graph_proto->set_bprop_hash(func_graph->bprop_hash());
@ -225,7 +226,10 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) { bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_proto);
for (auto &item : func_graph->parameters()) { for (auto &item : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(item);
auto param = item->cast<ParameterPtr>(); auto param = item->cast<ParameterPtr>();
if (param == nullptr) { if (param == nullptr) {
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
@ -296,6 +300,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
} }
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) { if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>(); auto tensor = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto elem_type = tensor->element(); auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
@ -328,6 +333,7 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0"); tensor_proto->set_name("value0");
auto data = value->cast<tensor::TensorPtr>(); auto data = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(data);
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type(); auto dtype = data->data_type();
auto shape = data->shape_c(); auto shape = data->shape_c();
@ -362,6 +368,7 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true; bool is_only_return = true;
for (const AnfNodePtr &node : nodes) { for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
continue; continue;
@ -380,13 +387,16 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
} }
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
if (node->size() != 2) { MS_EXCEPTION_IF_NULL(node);
const int OutputSize = 2;
if (node->size() != OutputSize) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
} }
AnfNodePtr arg = node->input(1); AnfNodePtr arg = node->input(1);
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output(); mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node); std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name); output_proto->set_name(output_name);
MS_EXCEPTION_IF_NULL(last_node_);
last_node_->set_output(0, output_name); last_node_->set_output(0, output_name);
SetValueInfoProto(arg, output_proto); SetValueInfoProto(arg, output_proto);
} }
@ -396,9 +406,11 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
std::string type_name = ""; std::string type_name = "";
if (IsValueNode<Primitive>(node)) { if (IsValueNode<Primitive>(node)) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node); PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(prim);
type_name = prim->ToString(); type_name = prim->ToString();
} else if (IsValueNode<FuncGraph>(node)) { } else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
MS_EXCEPTION_IF_NULL(fg);
todo_.push_back(fg); todo_.push_back(fg);
type_name = "REF::" + fg->ToString(); type_name = "REF::" + fg->ToString();
} else if (node->isa<CNode>() || node->isa<Parameter>()) { } else if (node->isa<CNode>() || node->isa<Parameter>()) {
@ -416,10 +428,9 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
if (seq_string == nullptr) { MS_EXCEPTION_IF_NULL(type);
MS_LOG(EXCEPTION) << "seq_string is nullptr."; MS_EXCEPTION_IF_NULL(shape);
} MS_EXCEPTION_IF_NULL(seq_string);
if (type->isa<Tuple>()) { if (type->isa<Tuple>()) {
*seq_string += "Tuple["; *seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements(); auto elements = type->cast<TuplePtr>()->elements();
@ -560,8 +571,9 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
} }
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::string node_name = ""; std::string node_name = "";
if ((node != nullptr) && (node->func_graph() != nullptr)) { if (node->func_graph() != nullptr) {
node_name = node->func_graph()->ToString() + ":"; node_name = node->func_graph()->ToString() + ":";
} }
if (node->isa<ValueNode>()) { if (node->isa<ValueNode>()) {
@ -578,7 +590,9 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
if (node == nullptr || node_proto == nullptr) { if (node == nullptr || node_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
} }
auto value = node->cast<ValueNodePtr>()->value(); auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
node_proto->set_op_type("Constant"); node_proto->set_op_type("Constant");
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value"); attr_proto->set_name("value");
@ -663,6 +677,9 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
} }
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar:value0"); attr_proto->set_ref_attr_name("scalar:value0");
if (value->isa<StringImm>()) { if (value->isa<StringImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
@ -709,6 +726,9 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
} }
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
if (value->isa<Int>()) { if (value->isa<Int>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
@ -803,6 +823,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
return; return;
} }
for (const auto &item : list_value->value()) { for (const auto &item : list_value->value()) {
MS_EXCEPTION_IF_NULL(item);
if (item->isa<ValueList>()) { if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string); SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
} else { } else {

View File

@ -55,6 +55,7 @@ using Data = ge::op::Data;
namespace { namespace {
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) { std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg);
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> { auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> vecs; std::vector<AnfNodePtr> vecs;
@ -132,6 +133,7 @@ OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
} }
void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) { void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
MS_EXCEPTION_IF_NULL(init_input);
if (this->training_) { if (this->training_) {
GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop"); auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
@ -237,6 +239,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_); std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
for (auto &it : nodes) { for (auto &it : nodes) {
MS_EXCEPTION_IF_NULL(it);
if (it->isa<ValueNode>()) { if (it->isa<ValueNode>()) {
if (IsValueNode<SymbolicKeyInstance>(it)) { if (IsValueNode<SymbolicKeyInstance>(it)) {
auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it); auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
@ -251,6 +254,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
} }
} else if (IsValueNode<RefKey>(it)) { } else if (IsValueNode<RefKey>(it)) {
auto refkey = GetValueNode<RefKeyPtr>(it); auto refkey = GetValueNode<RefKeyPtr>(it);
MS_EXCEPTION_IF_NULL(refkey);
auto name = refkey->tag(); auto name = refkey->tag();
auto iter = vars_.find(name); // get corresponding variable op auto iter = vars_.find(name); // get corresponding variable op
if (iter != vars_.end()) { if (iter != vars_.end()) {
@ -771,9 +775,10 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu
case_inputs.emplace_back(node->input(i)); case_inputs.emplace_back(node->input(i));
} }
auto bnode = input_node->input(2)->cast<CNodePtr>(); auto bnode = input_node->input(2)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bnode);
for (size_t i = 1; i < bnode->inputs().size(); i++) { for (size_t i = 1; i < bnode->inputs().size(); i++) {
auto branch_node = bnode->input(i)->cast<CNodePtr>(); auto branch_node = bnode->input(i)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(branch_node);
for (size_t j = 2; j < branch_node->inputs().size(); j++) { for (size_t j = 2; j < branch_node->inputs().size(); j++) {
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
case_inputs.emplace_back(branch_node->input(j)); case_inputs.emplace_back(branch_node->input(j));
@ -1073,7 +1078,9 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
} }
auto manager = func_graph->manager(); auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(node) == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "Can't find node in nodes_users.";
}
auto &users = manager->node_users()[node]; auto &users = manager->node_users()[node];
std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>(); std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>();
std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>(); std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
@ -1101,6 +1108,7 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
bool top) { bool top) {
MS_EXCEPTION_IF_NULL(node);
auto func_graph = node->func_graph(); auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager(); auto mng = func_graph->manager();
@ -1356,6 +1364,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
return; return;
} }
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>(); auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_node);
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>(); FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
DfGraphConvertor converter(anf_graph); DfGraphConvertor converter(anf_graph);
converter.use_inputs_ = true; converter.use_inputs_ = true;
@ -1449,13 +1458,16 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
} }
void DfGraphConvertor::ConvertTopK(const CNodePtr node) { void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32."; MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
auto value_ptr = node->input(2)->cast<ValueNodePtr>(); auto value_ptr = node->input(2)->cast<ValueNodePtr>();
std::ostringstream ss; std::ostringstream ss;
ss << "op" << value_ptr.get(); ss << "op" << value_ptr.get();
op_draw_name_[value_ptr.get()] = ss.str(); op_draw_name_[value_ptr.get()] = ss.str();
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl; compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value(); MS_EXCEPTION_IF_NULL(value_ptr);
auto input_value = value_ptr->value();
auto int64_value = GetValue<int64_t>(input_value);
OpAdapterPtr adpt = FindAdapter(value_ptr, training_); OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
auto op = adpt->generate(value_ptr); auto op = adpt->generate(value_ptr);
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value)); adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));

View File

@ -105,7 +105,7 @@ class Parameter(Tensor_):
>>> x = Tensor(np.ones((2, 1)), mindspore.float32) >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
>>> print(net(x)) >>> print(net(x))
[[2.]] [[2.]]
>>> _ = net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32)) >>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
>>> print(net(x)) >>> print(net(x))
[[0.]] [[0.]]
""" """

View File

@ -1232,7 +1232,7 @@ class Tensor(Tensor_):
raise ValueError(msg) raise ValueError(msg)
class seed_context: class seed_context:
'''set and restore seed''' """Set and restore seed."""
def __init__(self, init): def __init__(self, init):
self.init = init self.init = init

View File

@ -493,6 +493,7 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
shape.push_back(attr_tensor.dims(i)); shape.push_back(attr_tensor.dims(i));
} }
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data(); const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
@ -570,6 +571,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
shape.push_back(attr_tensor.dims(i)); shape.push_back(attr_tensor.dims(i));
} }
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data(); const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
@ -794,9 +796,11 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) { if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix)); auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name)); prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(op_name); prim->set_instance_name(op_name);
} else { } else {
prim = std::make_shared<Primitive>(node_type); prim = std::make_shared<Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type); prim->set_instance_name(node_type);
} }
} }
@ -940,10 +944,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
const mind_ir::ValueInfoProto &output_node = importProto.output(out_size); const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
const std::string &out_tuple = output_node.name(); const std::string &out_tuple = output_node.name();
if (anfnode_build_map_.find(out_tuple) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << "Can't find out_tuple in anfnode_build_map_";
return false;
}
inputs.push_back(anfnode_build_map_[out_tuple]); inputs.push_back(anfnode_build_map_[out_tuple]);
elem.push_back(anfnode_build_map_[out_tuple]->abstract()); elem.push_back(anfnode_build_map_[out_tuple]->abstract());
} }
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(maketuple_ptr);
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear(); inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn)); inputs.push_back(NewValueNode(prim::kPrimReturn));
@ -992,8 +1001,7 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
} }
} }
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); return BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
} }
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) { bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {

View File

@ -422,7 +422,6 @@ def _get_stack_info(frame):
Returns: Returns:
str, the string of the stack information. str, the string of the stack information.
""" """
sinfo = None
stack_prefix = 'Stack (most recent call last):\n' stack_prefix = 'Stack (most recent call last):\n'
sinfo = stack_prefix + "".join(traceback.format_stack(frame)) sinfo = stack_prefix + "".join(traceback.format_stack(frame))
return sinfo return sinfo

View File

@ -339,7 +339,7 @@ class Cell(Cell_):
def run_construct(self, cast_inputs, kwargs): def run_construct(self, cast_inputs, kwargs):
if self.enable_hook: if self.enable_hook:
output = self._hook_construct(*cast_inputs, **kwargs) output = self._hook_construct(*cast_inputs)
else: else:
output = self.construct(*cast_inputs, **kwargs) output = self.construct(*cast_inputs, **kwargs)
return output return output
@ -1209,7 +1209,7 @@ class Cell(Cell_):
self.add_flags(auto_parallel=True) self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name() self._get_construct_inputs_number_and_name()
def _hook_construct(self, *inputs, **kwargs): def _hook_construct(self, *inputs):
"""Hook construct method to replace original construct method when hook function enabled.""" """Hook construct method to replace original construct method when hook function enabled."""
inputs = self._backward_hook(*inputs) inputs = self._backward_hook(*inputs)
inputs = self.construct(inputs) inputs = self.construct(inputs)

View File

@ -81,9 +81,7 @@ def _column_or_1d(y):
Ravel column or 1d numpy array, otherwise raise an error. Ravel column or 1d numpy array, otherwise raise an error.
""" """
shape = np.shape(y) shape = np.shape(y)
if len(shape) == 1: if len(shape) == 1 or(len(shape) == 2 and shape[1] == 1):
return np.ravel(y)
if len(shape) == 2 and shape[1] == 1:
return np.ravel(y) return np.ravel(y)
raise ValueError("Bad input shape {0}.".format(shape)) raise ValueError("Bad input shape {0}.".format(shape))

View File

@ -113,13 +113,10 @@ def _check_param(momentum, frequency, lr, cls_name):
def caculate_device_shape(matrix_dim, channel, is_a): def caculate_device_shape(matrix_dim, channel, is_a):
ll = (0)
if is_a: if is_a:
if channel // C0 == 0: if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0 matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim) ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll return ll

View File

@ -200,7 +200,7 @@ def rollaxis(x, axis, start=0):
axis = _check_axes_range(axis, ndim) axis = _check_axes_range(axis, ndim)
start = _check_start_normalize(start, ndim) start = _check_start_normalize(start, ndim)
if start - axis >= 0 and start - axis <= 1: if 0 <= start - axis <= 1:
return x return x
perm = F.make_range(0, ndim) perm = F.make_range(0, ndim)
new_perm = None new_perm = None

View File

@ -43,6 +43,12 @@ class LossMonitor(Callback):
self._per_print_times = per_print_times self._per_print_times = per_print_times
def step_end(self, run_context): def step_end(self, run_context):
"""
Print training loss at the end of step.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args() cb_params = run_context.original_args()
loss = cb_params.net_outputs loss = cb_params.net_outputs

View File

@ -51,7 +51,6 @@ class LearningRateScheduler(Callback):
>>> dataset = create_custom_dataset("custom_dataset_path") >>> dataset = create_custom_dataset("custom_dataset_path")
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)], >>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
... dataset_sink_mode=False) ... dataset_sink_mode=False)
""" """
def __init__(self, learning_rate_function): def __init__(self, learning_rate_function):
@ -59,6 +58,12 @@ class LearningRateScheduler(Callback):
self.learning_rate_function = learning_rate_function self.learning_rate_function = learning_rate_function
def step_end(self, run_context): def step_end(self, run_context):
"""
Change the learning_rate at the end of step.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args() cb_params = run_context.original_args()
arr_lr = cb_params.optimizer.learning_rate.asnumpy() arr_lr = cb_params.optimizer.learning_rate.asnumpy()
lr = float(np.array2string(arr_lr)) lr = float(np.array2string(arr_lr))

View File

@ -38,9 +38,21 @@ class TimeMonitor(Callback):
self.epoch_time = time.time() self.epoch_time = time.time()
def epoch_begin(self, run_context): def epoch_begin(self, run_context):
"""
Record time at the begin of epoch.
Args:
run_context (RunContext): Context of the process running.
"""
self.epoch_time = time.time() self.epoch_time = time.time()
def epoch_end(self, run_context): def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
Args:
run_context (RunContext): Context of the process running.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000 epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size step_size = self.data_size
cb_params = run_context.original_args() cb_params = run_context.original_args()

View File

@ -824,7 +824,6 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
if os.path.exists(data_path): if os.path.exists(data_path):
shutil.rmtree(data_path) shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True) os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
index = 0 index = 0
graphproto = graph_proto() graphproto = graph_proto()
data_size = 0 data_size = 0

View File

@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
class ConvertNetUtils(): class ConvertNetUtils:
""" """
Convert net to thor layer net Convert net to thor layer net
""" """
@ -29,7 +29,6 @@ class ConvertNetUtils():
nn.Embedding: ConvertNetUtils._convert_embedding, nn.Embedding: ConvertNetUtils._convert_embedding,
nn.Conv2d: ConvertNetUtils._convert_conv2d} nn.Conv2d: ConvertNetUtils._convert_conv2d}
@staticmethod @staticmethod
def _convert_dense(subcell): def _convert_dense(subcell):
""" """
@ -64,7 +63,6 @@ class ConvertNetUtils():
new_subcell.bias = subcell.bias new_subcell.bias = subcell.bias
return new_subcell return new_subcell
@staticmethod @staticmethod
def _convert_embedding(subcell): def _convert_embedding(subcell):
""" """
@ -76,7 +74,6 @@ class ConvertNetUtils():
new_subcell.embedding_table = subcell.embedding_table new_subcell.embedding_table = subcell.embedding_table
return new_subcell return new_subcell
@staticmethod @staticmethod
def _convert_conv2d(subcell): def _convert_conv2d(subcell):
""" """
@ -95,7 +92,6 @@ class ConvertNetUtils():
has_bias=has_bias, weight_init=weight) has_bias=has_bias, weight_init=weight)
return new_subcell return new_subcell
def _convert_to_thor_net(self, net): def _convert_to_thor_net(self, net):
""" """
Convert net to thor net Convert net to thor net
@ -114,9 +110,6 @@ class ConvertNetUtils():
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)): elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
prefix = subcell.param_prefix prefix = subcell.param_prefix
new_subcell = self._convert_method_map[type(subcell)](subcell) new_subcell = self._convert_method_map[type(subcell)](subcell)
print("subcell name: ", name, "prefix is", prefix, flush=True)
if isinstance(new_subcell, (nn.DenseThor, nn.EmbeddingThor, nn.Conv2dThor)):
print("convert to thor layer success.", flush=True)
new_subcell.update_parameters_name(prefix + '.') new_subcell.update_parameters_name(prefix + '.')
net.insert_child_to_cell(name, new_subcell) net.insert_child_to_cell(name, new_subcell)
change = True change = True
@ -124,10 +117,8 @@ class ConvertNetUtils():
self._convert_to_thor_net(subcell) self._convert_to_thor_net(subcell)
if isinstance(net, nn.SequentialCell) and change: if isinstance(net, nn.SequentialCell) and change:
print("is nn.SequentialCell and change")
net.cell_list = list(net.cells()) net.cell_list = list(net.cells())
def convert_to_thor_net(self, net): def convert_to_thor_net(self, net):
""" """
This interface is used to convert a network to thor layer network, in order to calculate and store the This interface is used to convert a network to thor layer network, in order to calculate and store the
@ -152,7 +143,7 @@ class ConvertNetUtils():
net.update_cell_type("second-order") net.update_cell_type("second-order")
class ConvertModelUtils(): class ConvertModelUtils:
""" """
Convert model to thor model. Convert model to thor model.
""" """
@ -203,7 +194,7 @@ class ConvertModelUtils():
... frequency=100) ... frequency=100)
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"}, >>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
... amp_level="O2", keep_batchnorm_fp32=False) ... amp_level="O2", keep_batchnorm_fp32=False)
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt, >>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
... metrics={'acc'}, amp_level="O2", ... metrics={'acc'}, amp_level="O2",
... loss_scale_manager=loss_manager, ... loss_scale_manager=loss_manager,
... keep_batchnorm_fp32=False) ... keep_batchnorm_fp32=False)