forked from mindspore-Ecosystem/mindspore
add DeviceType
This commit is contained in:
parent
409584a138
commit
71a20c7353
|
@ -230,6 +230,7 @@ constexpr auto kSpliceContext = "context";
|
|||
constexpr auto kSpliceForwardIndexes = "forward_indexes";
|
||||
constexpr auto kSpliceOutputDims = "output_dim";
|
||||
constexpr auto kSideEffectIO = "side_effect_io";
|
||||
constexpr auto kDeviceType = "device_type";
|
||||
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
|
||||
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ table CNode {
|
|||
inputIndex: [uint];
|
||||
outputIndex: [uint];
|
||||
quantType: QuantType = QUANT_NONE;
|
||||
deviceType: int = -1; // 1 = CPU, 2 = GPU, 3 = NPU, -1 = UNKNOWN
|
||||
}
|
||||
|
||||
table SubGraph {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "mindspore/core/ops/op_utils.h"
|
||||
#include "ops/fusion/partial_fusion.h"
|
||||
#include "ops/depend.h"
|
||||
#include "ops/make_tuple.h"
|
||||
|
@ -341,6 +342,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
}
|
||||
node->name = cnode->fullname_with_scope();
|
||||
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
||||
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
|
||||
node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
|
||||
ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOpInputNode failed";
|
||||
|
|
Loading…
Reference in New Issue