diff --git a/mindspore/ccsrc/debug/debugger/proto_exporter.cc b/mindspore/ccsrc/debug/debugger/proto_exporter.cc index 3eb685ef4bd..ead91dc7692 100644 --- a/mindspore/ccsrc/debug/debugger/proto_exporter.cc +++ b/mindspore/ccsrc/debug/debugger/proto_exporter.cc @@ -33,7 +33,11 @@ namespace mindspore { -void CheckIfValidType(const TypePtr &type, debugger::TypeProto *type_proto) { +using TypeInfoToProtoTypeMap = std::vector>; + +void SetOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto); + +void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) { if (!(type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || type->isa() || @@ -45,8 +49,40 @@ void CheckIfValidType(const TypePtr &type, debugger::TypeProto *type_proto) { } } -void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, - debugger::TypeProto *type_proto) { +void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) { + TypePtr elem_type = dyn_cast(type)->element(); + type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); + if (shape != nullptr && shape->isa()) { + abstract::ShapePtr shape_info = dyn_cast(shape); + for (const auto &elem : shape_info->shape()) { + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); + } + } +} + +void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) { + TuplePtr tuple_type = dyn_cast(type); + for (const auto &elem_type : tuple_type->elements()) { + SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); + } +} + +void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) { + ListPtr list_type = dyn_cast(type); + for (const auto &elem_type : list_type->elements()) { + SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); + } +} + +static TypeInfoToProtoTypeMap type_info_to_proto_type = { + {typeid(TensorType).name(), debugger::DT_TENSOR}, {typeid(Tuple).name(), debugger::DT_TUPLE}, + {typeid(TypeType).name(), debugger::DT_TYPE}, {typeid(List).name(), debugger::DT_LIST}, + {typeid(TypeAnything).name(), debugger::DT_ANYTHING}, {typeid(RefKeyType).name(), debugger::DT_REFKEY}, + {typeid(RefType).name(), debugger::DT_REF}, {typeid(Function).name(), debugger::DT_GRAPH}, + {typeid(TypeNone).name(), debugger::DT_NONE}, {typeid(String).name(), debugger::DT_STRING}, + {typeid(UMonadType).name(), debugger::DT_UMONAD}, {typeid(IOMonadType).name(), debugger::DT_IOMONAD}}; + +void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) { if (type_proto == nullptr) { return; } @@ -55,46 +91,22 @@ void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseSha return; } CheckIfValidType(type, type_proto); + for (auto &it : type_info_to_proto_type) { + if (type->IsFromTypeId(Base::GetTypeId(it.first))) { + type_proto->set_data_type(it.second); + break; + } + } if (type->isa()) { - TypePtr elem_type = dyn_cast(type)->element(); - type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); - type_proto->set_data_type(debugger::DT_TENSOR); - if (shape != nullptr && shape->isa()) { - abstract::ShapePtr shape_info = dyn_cast(shape); - for (const auto &elem : shape_info->shape()) { - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); - } - } - } else if (type->isa()) { - TuplePtr tuple_type = dyn_cast(type); - type_proto->set_data_type(debugger::DT_TUPLE); - for (const auto &elem_type : tuple_type->elements()) { - SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); - } - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_TYPE); - } else if (type->isa()) { - ListPtr list_type = dyn_cast(type); - type_proto->set_data_type(debugger::DT_LIST); - for (const auto &elem_type : list_type->elements()) { - SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); - } - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_ANYTHING); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_REFKEY); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_REF); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_GRAPH); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_NONE); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_STRING); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_UMONAD); - } else if (type->isa()) { - type_proto->set_data_type(debugger::DT_IOMONAD); + SetTensorTypeProto(type, shape, type_proto); + return; + } + if (type->isa()) { + SetTupleTypeProto(type, type_proto); + return; + } + if (type->isa()) { + SetListTypeProto(type, type_proto); } } @@ -102,7 +114,7 @@ void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger:: if (node == nullptr || type_proto == nullptr) { return; } - SetNodeOutputType(node->Type(), node->Shape(), type_proto); + SetOutputType(node->Type(), node->Shape(), type_proto); } void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) { diff --git a/mindspore/ccsrc/debug/debugger/proto_exporter.h b/mindspore/ccsrc/debug/debugger/proto_exporter.h index 6a7c4e26796..31839f1ccb7 100644 --- a/mindspore/ccsrc/debug/debugger/proto_exporter.h +++ b/mindspore/ccsrc/debug/debugger/proto_exporter.h @@ -48,7 +48,6 @@ class DebuggerProtoExporter { void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto); void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto); void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto); - void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto); void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto, LocDebugDumpMode dump_location = kDebugOff); diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index ff68fcfdeec..3cf7c52832c 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -386,10 +386,8 @@ static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(graph_obj); graph_obj->buffer() << "label=<"; - graph_obj->buffer() << ""; - graph_obj->buffer() << "" + << ""; - graph_obj->buffer() << "" + << ""; - graph_obj->buffer() << "" + << ""; - graph_obj->buffer() << "
"; - graph_obj->buffer() << ValueType(node); - graph_obj->buffer() << "
"; + graph_obj->buffer() << "
" << ValueType(node) << "
"; if (IsValueNode(node)) { graph_obj->buffer() << node->value()->cast()->name(); } else if (IsValueNode(node)) { @@ -404,18 +402,16 @@ static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { ValuePtr value = node->value(); if (value->isa()) { PrimitivePtr primitive = value->cast(); - graph_obj->buffer() << "
"; + graph_obj->buffer() << "
"; if (!primitive->instance_name().empty()) { graph_obj->buffer() << "instance name:" - << " "; - graph_obj->buffer() << primitive->instance_name(); - graph_obj->buffer() << "
"; + << " " << primitive->instance_name() << "
"; } auto attrs = primitive->attrs(); if (attrs.size() > 0) { - graph_obj->buffer() << "
"; + graph_obj->buffer() << "
"; int i = 0; for (const auto &attr : attrs) { if (i != 0) { @@ -432,8 +428,8 @@ static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { } } } - graph_obj->buffer() << "
>,"; + graph_obj->buffer() << "" + << ">,"; } static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc index ba61ffcaeed..2f8cd628bbc 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc @@ -32,7 +32,8 @@ #include "base/core_ops.h" #include "abstract/abstract_value.h" -namespace mindspore::pipeline { +namespace mindspore { +namespace pipeline { namespace { // namespace anonymous using ClassTypePtr = std::shared_ptr; @@ -1467,4 +1468,5 @@ bool ReAutoMonad(const FuncGraphPtr &func_graph) { } return changed; } -} // namespace mindspore::pipeline +} // namespace pipeline +} // namespace mindspore