From 71a20c73535ada514f1e938963f2458bf571c37a Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Tue, 20 Apr 2021 20:50:47 +0800 Subject: [PATCH] add DeviceType --- mindspore/core/ops/op_utils.h | 1 + mindspore/lite/schema/model.fbs | 1 + mindspore/lite/tools/anf_exporter/anf_exporter.cc | 3 +++ 3 files changed, 5 insertions(+) diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 0d3f5923eaf..6f667cc9afd 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -230,6 +230,7 @@ constexpr auto kSpliceContext = "context"; constexpr auto kSpliceForwardIndexes = "forward_indexes"; constexpr auto kSpliceOutputDims = "output_dim"; constexpr auto kSideEffectIO = "side_effect_io"; +constexpr auto kDeviceType = "device_type"; const std::set common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 76d60c404c6..22769fcc72d 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -74,6 +74,7 @@ table CNode { inputIndex: [uint]; outputIndex: [uint]; quantType: QuantType = QUANT_NONE; + deviceType: int = -1; // 1 = CPU, 2 = GPU, 3 = NPU, -1 = UNKNOWN } table SubGraph { diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 233e7047219..8003bb7cbb0 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -26,6 +26,7 @@ #include "tools/common/tensor_util.h" #include "abstract/abstract_value.h" #include "mindspore/core/ir/primitive.h" +#include "mindspore/core/ops/op_utils.h" #include "ops/fusion/partial_fusion.h" #include "ops/depend.h" #include "ops/make_tuple.h" @@ -341,6 +342,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrname = cnode->fullname_with_scope(); node->primitive = std::unique_ptr(primT); + auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType); + node->deviceType = (device_type_attr != nullptr) ? GetValue(device_type_attr) : -1; ret = SetOpInputNode(cnode, meta_graphT, node.get()); if (ret != RET_OK) { MS_LOG(ERROR) << "SetOpInputNode failed";