add DeviceType

This commit is contained in:
zhujingxuan 2021-04-20 20:50:47 +08:00
parent 409584a138
commit 71a20c7353
3 changed files with 5 additions and 0 deletions

View File

@ -230,6 +230,7 @@ constexpr auto kSpliceContext = "context";
constexpr auto kSpliceForwardIndexes = "forward_indexes";
constexpr auto kSpliceOutputDims = "output_dim";
constexpr auto kSideEffectIO = "side_effect_io";
constexpr auto kDeviceType = "device_type";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};

View File

@ -74,6 +74,7 @@ table CNode {
inputIndex: [uint];
outputIndex: [uint];
quantType: QuantType = QUANT_NONE;
deviceType: int = -1; // 1 = CPU, 2 = GPU, 3 = NPU, -1 = UNKNOWN
}
table SubGraph {

View File

@ -26,6 +26,7 @@
#include "tools/common/tensor_util.h"
#include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/ops/op_utils.h"
#include "ops/fusion/partial_fusion.h"
#include "ops/depend.h"
#include "ops/make_tuple.h"
@ -341,6 +342,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
}
node->name = cnode->fullname_with_scope();
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
ret = SetOpInputNode(cnode, meta_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpInputNode failed";