forked from mindspore-Ecosystem/mindspore
Decouple AkgKernelJsonGenerator from MS backend (step 1)
* move GetInputTensorValue from common_utils to json_generator * get dtype size by `Number.nbits()` instead of `GetDtypeNbyte` map. * manually get attr from anfnode, instead of `AnfAlgo::GetNodeAttr` * replace `AnfAlgo::GetCNodePrimitive` with `GetCNodePrimitive` in anf.cc * it's not used to judge `AnfAlgo::IsRealKernel` in inner function. cleancode jobs: * remove the `Clean` function from AkgKernelJsonGenerator * delete the json key "id", to delete the mutex in AkgKernelJsonGenerator
This commit is contained in:
parent
748184c8e0
commit
05f0bd950f
|
@ -27,13 +27,10 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using kernel::GetDtypeNbyte;
|
||||
using kernel::GetInputIndex;
|
||||
using kernel::GetInputTensorValue;
|
||||
using kernel::GetKernelInput;
|
||||
using kernel::GetOutputIndex;
|
||||
using kernel::GetStrProcessorFromContext;
|
||||
using kernel::kProcessorCuda;
|
||||
using kernel::OpAttr;
|
||||
using kernel::OpImplyType;
|
||||
using kernel::OpInfo;
|
||||
|
@ -41,7 +38,7 @@ using kernel::OpIOInfo;
|
|||
namespace {
|
||||
std::vector<int> GetDynInputSize(const AnfNodePtr &anf_node) {
|
||||
std::vector<int> dyn_input_sizes;
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
auto primitive = GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (primitive->HasAttr(kAttrDynInputSizes)) {
|
||||
std::vector<int64_t> dyn_input_sizes_me =
|
||||
|
@ -145,16 +142,6 @@ class OpInfoExtractor {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
int AkgKernelJsonGenerator::op_cnt_ = 0;
|
||||
std::mutex AkgKernelJsonGenerator::op_cnt_mtx_;
|
||||
|
||||
int AkgKernelJsonGenerator::GetOpCntInc() {
|
||||
op_cnt_mtx_.lock();
|
||||
int cnt = op_cnt_++;
|
||||
op_cnt_mtx_.unlock();
|
||||
return cnt;
|
||||
}
|
||||
|
||||
TypeId AkgKernelJsonGenerator::GetInputDataType(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index)
|
||||
: AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
|
||||
|
@ -183,6 +170,68 @@ std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf_node,
|
|||
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(anf_node, index);
|
||||
}
|
||||
|
||||
bool AkgKernelJsonGenerator::GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
nlohmann::json *node_json) const {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(node_json);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->size()) {
|
||||
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
||||
<< cnode->inputs().size() << "][" << cnode->DebugString() << "]";
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
if (!IsValueNode<tensor::Tensor>(input_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type_id = tensor->data_type();
|
||||
auto *data = tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (tensor->DataSize() > 1) {
|
||||
// not const tensor.
|
||||
MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type_id == kFloat64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<double *>(data)[0];
|
||||
} else if (type_id == kFloat32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<float *>(data)[0];
|
||||
} else if (type_id == kFloat16->type_id()) {
|
||||
float16 *val = static_cast<float16 *>(data);
|
||||
(*node_json)["value"] = static_cast<float>(val[0]);
|
||||
} else if (type_id == kUInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint64_t *>(data)[0];
|
||||
} else if (type_id == kUInt32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint32_t *>(data)[0];
|
||||
} else if (type_id == kUInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint16_t *>(data)[0];
|
||||
} else if (type_id == kUInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint8_t *>(data)[0];
|
||||
} else if (type_id == kInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int64_t *>(data)[0];
|
||||
} else if (type_id == kInt32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int32_t *>(data)[0];
|
||||
} else if (type_id == kInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int16_t *>(data)[0];
|
||||
} else if (type_id == kInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int8_t *>(data)[0];
|
||||
} else if (type_id == kBool->type_id()) {
|
||||
(*node_json)["value"] = static_cast<bool *>(data)[0];
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info,
|
||||
nlohmann::json *inputs_json) {
|
||||
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
|
||||
|
@ -321,7 +370,7 @@ bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, cons
|
|||
return true;
|
||||
}
|
||||
auto dyn_input_sizes = GetDynInputSize(anf_node);
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
auto primitive = GetCNodePrimitive(anf_node);
|
||||
|
||||
// create input name list for "x_shape" in attr with "x" in primitive.
|
||||
auto inputs = op_info->inputs_ptr();
|
||||
|
@ -481,7 +530,7 @@ bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node
|
|||
|
||||
// get basic params from currentNodeOpDesc
|
||||
if (IsPrimitiveCNode(anf_node, prim::kPrimCustom)) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
auto primitive = GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
(*node_json)[kJsonKeyName] = primitive->name();
|
||||
} else {
|
||||
|
@ -518,6 +567,17 @@ bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node
|
|||
return true;
|
||||
}
|
||||
|
||||
size_t AkgKernelJsonGenerator::GetTensorSize(const nlohmann::json &node_json) const {
|
||||
const std::vector<size_t> &shape = node_json[kJsonKeyShape];
|
||||
const std::string &dtype = node_json[kJsonKeyDataType];
|
||||
auto type_ptr = StringToType(dtype);
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto num_ptr = type_ptr->cast<NumberPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_ptr);
|
||||
size_t nbyte = IntToSize(num_ptr->nbits() / static_cast<int>(BitsNum::eBits8));
|
||||
return std::accumulate(shape.begin(), shape.end(), nbyte, std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size,
|
||||
std::vector<size_t> *output_size) const {
|
||||
if (input_size == nullptr || output_size == nullptr) {
|
||||
|
@ -529,22 +589,12 @@ bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vec
|
|||
|
||||
for (size_t i = 0; i < node_json[kJsonKeyInputDesc].size(); i++) {
|
||||
for (size_t m = 0; m < node_json[kJsonKeyInputDesc][i].size(); m++) {
|
||||
std::string dtype = node_json[kJsonKeyInputDesc][i][m][kJsonKeyDataType];
|
||||
size_t nbyte = GetDtypeNbyte(dtype);
|
||||
size_t size_i =
|
||||
std::accumulate(node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].begin(),
|
||||
node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
|
||||
input_size->push_back(size_i);
|
||||
input_size->push_back(GetTensorSize(node_json[kJsonKeyInputDesc][i][m]));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < node_json[kJsonKeyOutputDesc].size(); i++) {
|
||||
std::string dtype = node_json[kJsonKeyOutputDesc][i][kJsonKeyDataType];
|
||||
size_t nbyte = GetDtypeNbyte(dtype);
|
||||
size_t size_i =
|
||||
std::accumulate(node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].begin(),
|
||||
node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
|
||||
output_size->push_back(size_i);
|
||||
output_size->push_back(GetTensorSize(node_json[kJsonKeyOutputDesc][i]));
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -564,7 +614,7 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
|
|||
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
|
||||
kernel_name_ = op_name + "_";
|
||||
(void)kernel_name_.append(std::to_string(hash_id));
|
||||
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
|
||||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||
(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext(); // GetProcessorStr(anf_node);
|
||||
|
@ -588,8 +638,10 @@ void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_no
|
|||
nlohmann::json *kernel_json) {
|
||||
std::vector<std::string> stitchs;
|
||||
for (auto const &anf_node : anf_nodes) {
|
||||
if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) &&
|
||||
AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") {
|
||||
auto prim = GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto stitch_attr = prim->GetAttr(kAttrStitch);
|
||||
if (stitch_attr != nullptr && GetValue<std::string>(stitch_attr) == "common") {
|
||||
auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
|
||||
if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
|
||||
stitchs.emplace_back(name);
|
||||
|
@ -656,7 +708,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
|
|||
static_cast<void>(kernel_name_.append(fg_name).append("_"));
|
||||
}
|
||||
static_cast<void>(kernel_name_.append(std::to_string(hash_id)));
|
||||
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
|
||||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||
(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext();
|
||||
|
@ -679,18 +731,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
|
|||
bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes,
|
||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
|
||||
for (auto const &anf_node : anf_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (!AnfAlgo::IsRealKernel(anf_node)) {
|
||||
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
|
||||
return false;
|
||||
}
|
||||
nlohmann::json node_json;
|
||||
if (!GenerateSingleKernelJson(anf_node, &node_json)) {
|
||||
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
auto primitive = GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
(*node_json_map)[anf_node] = node_json;
|
||||
|
@ -770,9 +817,11 @@ void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_
|
|||
if (tcnode == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(tcnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// Get dim info.
|
||||
if (AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
|
||||
auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
|
||||
if (prim->HasAttr(kAttrParallelDimInfo)) {
|
||||
auto info = GetValue<std::vector<size_t>>(prim->GetAttr(kAttrParallelDimInfo));
|
||||
if (info.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
|
||||
}
|
||||
|
@ -782,22 +831,17 @@ void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_
|
|||
sub_graphs_info[info[0]].first = info[1];
|
||||
}
|
||||
// Get fusion type.
|
||||
if (AnfAlgo::HasNodeAttr(kAttrParallelFusionType, tcnode)) {
|
||||
fusion_type = AnfAlgo::GetNodeAttr<std::string>(tcnode, kAttrParallelFusionType);
|
||||
if (prim->HasAttr(kAttrParallelFusionType)) {
|
||||
fusion_type = GetValue<std::string>(prim->GetAttr(kAttrParallelFusionType));
|
||||
}
|
||||
// Get fusion type info.
|
||||
if (AnfAlgo::HasNodeAttr(kAttrParallelTypeInfo, tcnode)) {
|
||||
type_info = AnfAlgo::GetNodeAttr<std::vector<std::vector<int>>>(tcnode, kAttrParallelTypeInfo);
|
||||
if (prim->HasAttr(kAttrParallelTypeInfo)) {
|
||||
type_info = GetValue<std::vector<std::vector<int>>>(prim->GetAttr(kAttrParallelTypeInfo));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!sub_graphs_info.empty()) {
|
||||
auto processor = GetStrProcessorFromContext(); // GetProcessorStr(anf_nodes[0]);
|
||||
if (processor != kProcessorCuda) {
|
||||
MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now.";
|
||||
}
|
||||
|
||||
nlohmann::json parallel_fusion_json;
|
||||
parallel_fusion_json[kJsonKeyFusionType] = fusion_type;
|
||||
parallel_fusion_json[kJsonKeyTypeInfo] = type_info;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "backend/kernel_compiler/oplib/opinfo.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
|
@ -63,8 +63,6 @@ constexpr auto kJsonKeyStitchOp = "stitch_op";
|
|||
constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
|
||||
constexpr auto kJsonKeyComputeCapability = "compute_capability";
|
||||
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
||||
// dump option
|
||||
struct DumpOption {
|
||||
bool is_before_select_kernel = false;
|
||||
|
@ -91,8 +89,8 @@ class ComputeCapability {
|
|||
|
||||
class AkgKernelJsonGenerator {
|
||||
public:
|
||||
AkgKernelJsonGenerator() { Clear(); }
|
||||
explicit AkgKernelJsonGenerator(DumpOption dump_option) : dump_option_(dump_option) { Clear(); }
|
||||
AkgKernelJsonGenerator() = default;
|
||||
explicit AkgKernelJsonGenerator(DumpOption dump_option) : dump_option_(std::move(dump_option)) {}
|
||||
~AkgKernelJsonGenerator() = default;
|
||||
|
||||
bool CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json);
|
||||
|
@ -107,11 +105,6 @@ class AkgKernelJsonGenerator {
|
|||
std::string kernel_json_str() const { return kernel_json_.dump(); }
|
||||
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
|
||||
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
|
||||
void Clear() {
|
||||
input_tensor_idx_.clear();
|
||||
address_node_map_.clear();
|
||||
output_tensor_idx_ = 0;
|
||||
}
|
||||
void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
|
||||
std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
|
||||
|
||||
|
@ -132,8 +125,6 @@ class AkgKernelJsonGenerator {
|
|||
nlohmann::json CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
|
||||
const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
|
||||
|
||||
int GetOpCntInc();
|
||||
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
|
||||
size_t GetOutputTensorIdxInc();
|
||||
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
||||
|
@ -152,14 +143,13 @@ class AkgKernelJsonGenerator {
|
|||
void GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list,
|
||||
const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json);
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *node_json) const;
|
||||
size_t GetTensorSize(const nlohmann::json &node_json) const;
|
||||
|
||||
DumpOption dump_option_;
|
||||
static int op_cnt_;
|
||||
// lock for variable fusionOpCnt in singleton mode
|
||||
static std::mutex op_cnt_mtx_;
|
||||
std::string kernel_name_;
|
||||
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
|
||||
size_t output_tensor_idx_;
|
||||
size_t output_tensor_idx_{0};
|
||||
nlohmann::json kernel_json_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
|
|
@ -714,67 +714,6 @@ void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNode
|
|||
}
|
||||
}
|
||||
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(node_json);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->size()) {
|
||||
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
||||
<< cnode->inputs().size() << "][" << cnode->DebugString() << "]";
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
if (!IsValueNode<tensor::Tensor>(input_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type_id = tensor->data_type();
|
||||
auto *data = tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (tensor->DataSize() > 1) {
|
||||
// not const tensor.
|
||||
MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type_id == kFloat64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<double *>(data)[0];
|
||||
} else if (type_id == kFloat32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<float *>(data)[0];
|
||||
} else if (type_id == kFloat16->type_id()) {
|
||||
float16 *val = static_cast<float16 *>(data);
|
||||
(*node_json)["value"] = static_cast<float>(val[0]);
|
||||
} else if (type_id == kUInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint64_t *>(data)[0];
|
||||
} else if (type_id == kUInt32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint32_t *>(data)[0];
|
||||
} else if (type_id == kUInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint16_t *>(data)[0];
|
||||
} else if (type_id == kUInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint8_t *>(data)[0];
|
||||
} else if (type_id == kInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int64_t *>(data)[0];
|
||||
} else if (type_id == kInt32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int32_t *>(data)[0];
|
||||
} else if (type_id == kInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int16_t *>(data)[0];
|
||||
} else if (type_id == kInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int8_t *>(data)[0];
|
||||
} else if (type_id == kBool->type_id()) {
|
||||
(*node_json)["value"] = static_cast<bool *>(data)[0];
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsWeightBoundary(const AnfNodePtr &node) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
|
|
|
@ -98,7 +98,6 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
|
|||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
||||
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
|
||||
void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list);
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
|
||||
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
|
||||
bool IsWeightBoundary(const AnfNodePtr &node);
|
||||
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode);
|
||||
|
|
Loading…
Reference in New Issue