!2285 [Code Review] code review fix

Merge pull request !2285 from jjfeing/master
This commit is contained in:
mindspore-ci-bot 2020-06-19 10:48:21 +08:00 committed by Gitee
commit a663f2066c
8 changed files with 186 additions and 136 deletions

View File

@ -37,9 +37,9 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
using device::ascend::ProfilingUtils; using device::ascend::ProfilingUtils;
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(kernel_graph);
const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order(); const std::vector<CNodePtr> &origin_cnode_list = kernel_graph->execution_order();
std::vector<CNodePtr> momentum_list; std::vector<CNodePtr> momentum_list;
std::vector<CNodePtr> other_list; std::vector<CNodePtr> other_list;
for (const auto &cnode : origin_cnode_list) { for (const auto &cnode : origin_cnode_list) {
@ -52,7 +52,7 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g
std::vector<CNodePtr> new_order_list; std::vector<CNodePtr> new_order_list;
new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end());
new_order_list.insert(new_order_list.end(), momentum_list.begin(), momentum_list.end()); new_order_list.insert(new_order_list.end(), momentum_list.begin(), momentum_list.end());
kernel_graph_ptr->set_execution_order(new_order_list); kernel_graph->set_execution_order(new_order_list);
} }
void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {

View File

@ -51,7 +51,7 @@ class KernelAdjust {
static KernelAdjust instance; static KernelAdjust instance;
return instance; return instance;
} }
void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr); void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);

View File

@ -103,6 +103,7 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
} }
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::OutputAddrExist(kernel, index)) { if (AnfAlgo::OutputAddrExist(kernel, index)) {
return true; return true;
} }
@ -217,6 +218,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
auto device_address = auto device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(mem_manager_);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
if (!ret) { if (!ret) {
MS_LOG(EXCEPTION) << "Malloc device memory failed."; MS_LOG(EXCEPTION) << "Malloc device memory failed.";
@ -618,6 +620,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>(); kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
input->addr = device_address->ptr_; input->addr = device_address->ptr_;

View File

@ -68,6 +68,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
} else if (flag == kDynamicMem) { } else if (flag == kDynamicMem) {
ptr = MallocDynamicMem(size, false); ptr = MallocDynamicMem(size, false);
} else if (flag == kReuseDynamicMem) { } else if (flag == kReuseDynamicMem) {
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
} }
return ptr; return ptr;
@ -75,6 +76,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
if (flag == kReuseDynamicMem) { if (flag == kReuseDynamicMem) {
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
} }
return MallocDynamicMem(size, false); return MallocDynamicMem(size, false);

View File

@ -15,15 +15,12 @@
*/ */
#include "kernel/tbe/tbe_kernel_build.h" #include "kernel/tbe/tbe_kernel_build.h"
#include <memory> #include <memory>
#include <map> #include <map>
#include <algorithm> #include <algorithm>
#include <unordered_set>
#include "operator/ops.h" #include "operator/ops.h"
#include "parallel/ops_info/ops_utils.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_mod.h"
#include "kernel/tbe/tbe_adapter.h" #include "kernel/tbe/tbe_adapter.h"
#include "kernel/tbe/tbe_python_funcs.h" #include "kernel/tbe/tbe_python_funcs.h"
#include "kernel/tbe/tbe_convert_utils.h" #include "kernel/tbe/tbe_convert_utils.h"
@ -37,6 +34,42 @@ constexpr auto kFusionOpList = "op_list";
constexpr auto kFusionKernelNamePrfix = "te_fusion"; constexpr auto kFusionKernelNamePrfix = "te_fusion";
constexpr auto kOptional = "optional_"; constexpr auto kOptional = "optional_";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z"; constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
constexpr auto kPlatform = "platform";
constexpr auto kPlatTBE = "TBE";
constexpr auto kGenModel = "gen_model";
constexpr auto kSingle = "single";
constexpr auto kImplPath = "impl_path";
constexpr auto kJInputs = "inputs";
constexpr auto kJOutputs = "outputs";
constexpr auto kJAttrs = "attrs";
constexpr auto kJKernelName = "kernel_name";
constexpr auto kJOpInfo = "op_info";
constexpr auto kJDtype = "dtype";
constexpr auto kJtype = "type";
constexpr auto kJName = "name";
constexpr auto kJOriShape = "ori_shape";
constexpr auto kJOriFormat = "ori_format";
constexpr auto kJShape = "shape";
constexpr auto kJFormat = "format";
constexpr auto kJValid = "valid";
constexpr auto kJParamType = "param_type";
constexpr auto kParamDynamic = "dynamic";
constexpr auto kParamRequred = "required";
constexpr auto kJDataType = "data_type";
constexpr auto kJOutputIndex = "output_index";
constexpr auto kJOutputDesc = "output_desc";
constexpr auto kJInputDesc = "input_desc";
constexpr auto kVTypeInt = "int";
constexpr auto kVTypeStr = "str";
constexpr auto kVTypeBool = "bool";
constexpr auto kVTypeFloat = "float";
constexpr auto kVTypeListInt = "listInt";
constexpr auto kVTypeInt32 = "Int32";
constexpr auto kVTypeListFloat = "listFloat";
constexpr auto kVTypeListListInt = "listListInt";
constexpr auto kJValue = "value";
constexpr auto kJDynIndex = "dyn_index";
constexpr auto kJFuncName = "func_name";
std::string NormalizeFullScopeName(const string &full_scope_name) { std::string NormalizeFullScopeName(const string &full_scope_name) {
// exp:Default/ReLU-op0 -->Default_ReLU_op0 // exp:Default/ReLU-op0 -->Default_ReLU_op0
@ -46,51 +79,51 @@ std::string NormalizeFullScopeName(const string &full_scope_name) {
return normal_ret; return normal_ret;
} }
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const shared_ptr<mindspore::AnfNode> &anf_node, bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *kernel_json) { nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_json); MS_EXCEPTION_IF_NULL(kernel_json);
std::string op_name = AnfAlgo::GetCNodeName(anf_node); std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
MS_EXCEPTION_IF_NULL(op_info_ptr); MS_EXCEPTION_IF_NULL(op_info_ptr);
(*kernel_json)["platform"] = "TBE"; (*kernel_json)[kPlatform] = kPlatTBE;
(*kernel_json)["gen_model"] = "single"; (*kernel_json)[kGenModel] = kSingle;
(*kernel_json)["impl_path"] = op_info_ptr->impl_path(); (*kernel_json)[kImplPath] = op_info_ptr->impl_path();
nlohmann::json op_info_json; nlohmann::json op_info_json;
if (op_info_ptr->impl_path().empty()) { if (op_info_ptr->impl_path().empty()) {
tbe::TbeAdapter::NormalizeFuncName(&op_name); tbe::TbeAdapter::NormalizeFuncName(&op_name);
} else { } else {
op_name = op_info_ptr->kernel_name(); op_name = op_info_ptr->kernel_name();
} }
op_info_json["name"] = op_name; op_info_json[kJName] = op_name;
// generate inputs json // generate inputs json
nlohmann::json inputs_json; nlohmann::json inputs_json;
if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) {
MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed"; MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed";
return false; return false;
} }
op_info_json["inputs"] = inputs_json; op_info_json[kJInputs] = inputs_json;
// generate outputs json // generate outputs json
nlohmann::json outputs_json; nlohmann::json outputs_json;
if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) { if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) {
MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed"; MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed";
return false; return false;
} }
op_info_json["outputs"] = outputs_json; op_info_json[kJOutputs] = outputs_json;
// generate attrs json // generate attrs json
nlohmann::json attrs_json; nlohmann::json attrs_json;
(void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json); (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
op_info_json["attrs"] = attrs_json; op_info_json[kJAttrs] = attrs_json;
std::string json_str = op_info_json.dump(); std::string json_str = op_info_json.dump();
size_t hash_id = std::hash<std::string>()(json_str); size_t hash_id = std::hash<std::string>()(json_str);
json_name_ = op_name + "_" + std::to_string(hash_id); json_name_ = op_name + "_" + std::to_string(hash_id);
json_info_ = json_str; json_info_ = json_str;
if (creater_type_ == PREBUILD) { if (creater_type_ == PREBUILD) {
op_info_json["kernel_name"] = NormalizeFullScopeName(anf_node->fullname_with_scope()); op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope());
} else { } else {
op_info_json["kernel_name"] = json_name_; op_info_json[kJKernelName] = json_name_;
} }
(*kernel_json)["op_info"] = op_info_json; (*kernel_json)[kJOpInfo] = op_info_json;
if (creater_type_ == SINGLE_BUILD) { if (creater_type_ == SINGLE_BUILD) {
TbeUtils::SaveJsonInfo(json_name_, json_info_); TbeUtils::SaveJsonInfo(json_name_, json_info_);
} }
@ -101,9 +134,10 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const shared_ptr<mindspore::An
return true; return true;
} }
bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value, bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
const shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
size_t input_i, vector<nlohmann::json> *input_list) { const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(input_ptr); MS_EXCEPTION_IF_NULL(input_ptr);
MS_EXCEPTION_IF_NULL(input_list); MS_EXCEPTION_IF_NULL(input_list);
@ -119,22 +153,22 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
ori_shape.emplace_back(1); ori_shape.emplace_back(1);
} }
nlohmann::json input_desc_json; nlohmann::json input_desc_json;
input_desc_json["dtype"] = dtype; input_desc_json[kJDtype] = dtype;
input_desc_json["name"] = op_input_name + std::to_string(input_i); input_desc_json[kJName] = op_input_name + std::to_string(input_i);
input_desc_json["ori_shape"] = ori_shape; input_desc_json[kJOriShape] = ori_shape;
input_desc_json["ori_format"] = kOpFormat_NCHW; input_desc_json[kJOriFormat] = kOpFormat_NCHW;
input_desc_json["shape"] = shape; input_desc_json[kJShape] = shape;
input_desc_json["format"] = format; input_desc_json[kJFormat] = format;
input_desc_json["valid"] = value; input_desc_json[kJValid] = value;
input_desc_json["param_type"] = input_ptr->param_type(); input_desc_json[kJParamType] = input_ptr->param_type();
input_list->emplace_back(input_desc_json); input_list->emplace_back(input_desc_json);
} }
return true; return true;
} }
bool TbeKernelJsonCreator::GenInputList(const shared_ptr<AnfNode> &anf_node, size_t input_tensor_num, bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr<AnfNode> &anf_node, size_t input_tensor_num,
const shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index, const std::shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index,
string *op_input_name, vector<nlohmann::json> *input_list) { string *op_input_name, std::vector<nlohmann::json> *input_list) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(input_ptr); MS_EXCEPTION_IF_NULL(input_ptr);
MS_EXCEPTION_IF_NULL(real_input_index); MS_EXCEPTION_IF_NULL(real_input_index);
@ -149,8 +183,8 @@ bool TbeKernelJsonCreator::GenInputList(const shared_ptr<AnfNode> &anf_node, siz
if (input_ptr->param_type() == "optional") { if (input_ptr->param_type() == "optional") {
*op_input_name = input_ptr->name() + "_optional_"; *op_input_name = input_ptr->name() + "_optional_";
nlohmann::json input_desc_json; nlohmann::json input_desc_json;
input_desc_json["valid"] = false; input_desc_json[kJValid] = false;
input_desc_json["name"] = *op_input_name + std::to_string(*real_input_index); input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index);
input_list->emplace_back(input_desc_json); input_list->emplace_back(input_desc_json);
continue; continue;
} }
@ -179,7 +213,7 @@ bool TbeKernelJsonCreator::GenInputList(const shared_ptr<AnfNode> &anf_node, siz
return true; return true;
} }
bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const shared_ptr<OpIOInfo> &input_ptr, bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpIOInfo> &input_ptr,
size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) { size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(input_ptr); MS_EXCEPTION_IF_NULL(input_ptr);
@ -193,7 +227,7 @@ bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const shar
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes)); dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
} }
if (input_ptr->param_type() == "dynamic") { if (input_ptr->param_type() == kParamDynamic) {
if (*dyn_input_index >= dyn_input_sizes.size()) { if (*dyn_input_index >= dyn_input_sizes.size()) {
MS_LOG(ERROR) << "dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size(); MS_LOG(ERROR) << "dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
return false; return false;
@ -259,9 +293,9 @@ bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr<AnfNode> &anf
return GenOutputDescJson(anf_node, outputs_ptr, outputs_json); return GenOutputDescJson(anf_node, outputs_ptr, outputs_json);
} }
bool TbeKernelJsonCreator::GenOutputDescJson(const shared_ptr<mindspore::AnfNode> &anf_node, bool TbeKernelJsonCreator::GenOutputDescJson(
const vector<shared_ptr<mindspore::kernel::OpIOInfo>> &outputs_ptr, const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *outputs_json) { const std::vector<std::shared_ptr<mindspore::kernel::OpIOInfo>> &outputs_ptr, nlohmann::json *outputs_json) {
MS_EXCEPTION_IF_NULL(outputs_json); MS_EXCEPTION_IF_NULL(outputs_json);
size_t output_idx = 0; size_t output_idx = 0;
auto op_name = AnfAlgo::GetCNodeName(anf_node); auto op_name = AnfAlgo::GetCNodeName(anf_node);
@ -269,9 +303,9 @@ bool TbeKernelJsonCreator::GenOutputDescJson(const shared_ptr<mindspore::AnfNode
for (const auto &output_ptr : outputs_ptr) { for (const auto &output_ptr : outputs_ptr) {
size_t output_obj_num = 0; size_t output_obj_num = 0;
if (output_ptr->param_type() == "required") { if (output_ptr->param_type() == kParamRequred) {
output_obj_num = 1; output_obj_num = 1;
} else if (output_ptr->param_type() == "dynamic") { } else if (output_ptr->param_type() == kParamDynamic) {
if (outputs_ptr.size() > 1) { if (outputs_ptr.size() > 1) {
MS_LOG(ERROR) << "Dynamic output is unsupported multi output!"; MS_LOG(ERROR) << "Dynamic output is unsupported multi output!";
return false; return false;
@ -282,8 +316,8 @@ bool TbeKernelJsonCreator::GenOutputDescJson(const shared_ptr<mindspore::AnfNode
MS_LOG(INFO) << "op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none."; MS_LOG(INFO) << "op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none.";
std::vector<nlohmann::json> output_list; std::vector<nlohmann::json> output_list;
nlohmann::json output_obj; nlohmann::json output_obj;
output_obj["name"] = output_ptr->name(); output_obj[kJName] = output_ptr->name();
output_obj["valid"] = false; output_obj[kJValid] = false;
output_list.emplace_back(output_obj); output_list.emplace_back(output_obj);
(*outputs_json).push_back(output_list); (*outputs_json).push_back(output_list);
continue; continue;
@ -298,9 +332,9 @@ bool TbeKernelJsonCreator::GenOutputDescJson(const shared_ptr<mindspore::AnfNode
return true; return true;
} }
void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num, void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
const shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx, const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
vector<nlohmann::json> *output_list) { std::vector<nlohmann::json> *output_list) {
MS_EXCEPTION_IF_NULL(output_idx); MS_EXCEPTION_IF_NULL(output_idx);
MS_EXCEPTION_IF_NULL(output_list); MS_EXCEPTION_IF_NULL(output_list);
for (size_t i = 0; i < output_obj_num; i++) { for (size_t i = 0; i < output_obj_num; i++) {
@ -312,14 +346,14 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
ori_shape.emplace_back(1); ori_shape.emplace_back(1);
} }
nlohmann::json output_obj; nlohmann::json output_obj;
output_obj["dtype"] = dtype; output_obj[kJDtype] = dtype;
output_obj["shape"] = shape; output_obj[kJShape] = shape;
output_obj["format"] = format; output_obj[kJFormat] = format;
output_obj["ori_shape"] = ori_shape; output_obj[kJOriShape] = ori_shape;
output_obj["ori_format"] = kOpFormat_NCHW; output_obj[kJOriFormat] = kOpFormat_NCHW;
output_obj["name"] = output_ptr->name(); output_obj[kJName] = output_ptr->name();
output_obj["valid"] = true; output_obj[kJValid] = true;
output_obj["param_type"] = output_ptr->param_type(); output_obj[kJParamType] = output_ptr->param_type();
output_list->emplace_back(output_obj); output_list->emplace_back(output_obj);
(*output_idx)++; (*output_idx)++;
} }
@ -340,24 +374,24 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no
for (const auto &attr_ptr : attrs_ptr) { for (const auto &attr_ptr : attrs_ptr) {
std::string attr_name = attr_ptr->name(); std::string attr_name = attr_ptr->name();
nlohmann::json attr_obj; nlohmann::json attr_obj;
attr_obj["name"] = attr_name; attr_obj[kJName] = attr_name;
if (op_name == "LayerNorm" && attr_obj["name"] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) { if (op_name == parallel::LAYER_NORM && attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) {
continue; continue;
} }
if (primitive->GetAttr(attr_name) != nullptr) { if (primitive->GetAttr(attr_name) != nullptr) {
auto value = primitive->GetAttr(attr_name); auto value = primitive->GetAttr(attr_name);
std::string type = attr_ptr->type(); std::string type = attr_ptr->type();
ParseAttrValue(type, value, &attr_obj); ParseAttrValue(type, value, &attr_obj);
attr_obj["valid"] = true; attr_obj[kJValid] = true;
} else { } else {
if (op_info->impl_path().empty()) { if (op_info->impl_path().empty()) {
attr_obj["valid"] = false; attr_obj[kJValid] = false;
} else { } else {
if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD) { if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) {
MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name
<< " is required, but not set."; << " is required, but not set.";
} else { } else {
attr_obj["valid"] = false; attr_obj[kJValid] = false;
} }
} }
} }
@ -370,48 +404,48 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
nlohmann::json *attr_obj) { nlohmann::json *attr_obj) {
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(attr_obj); MS_EXCEPTION_IF_NULL(attr_obj);
if (type == "int") { if (type == kVTypeInt) {
auto attr_value = GetValue<int>(value); auto attr_value = GetValue<int>(value);
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "str") { } else if (type == kVTypeStr) {
auto attr_value = GetValue<std::string>(value); auto attr_value = GetValue<std::string>(value);
if (attr_value == kOpFormat_FRAC_Z) { if (attr_value == kOpFormat_FRAC_Z) {
attr_value = kOpFormat_FRACTAL_Z; attr_value = kOpFormat_FRACTAL_Z;
} }
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "bool") { } else if (type == kVTypeBool) {
auto attr_value = GetValue<bool>(value); auto attr_value = GetValue<bool>(value);
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "float") { } else if (type == kVTypeFloat) {
auto attr_value = GetValue<float>(value); auto attr_value = GetValue<float>(value);
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "listInt") { } else if (type == kVTypeListInt) {
std::vector<int> attr_value; std::vector<int> attr_value;
auto value_type = value->type(); auto value_type = value->type();
MS_EXCEPTION_IF_NULL(value_type); MS_EXCEPTION_IF_NULL(value_type);
auto value_type_str = value_type->ToString(); auto value_type_str = value_type->ToString();
if (value_type_str == "Int32") { if (value_type_str == kVTypeInt32) {
int data = GetValue<int>(value); int data = GetValue<int>(value);
attr_value.push_back(data); attr_value.push_back(data);
} else { } else {
attr_value = GetValue<std::vector<int>>(value); attr_value = GetValue<std::vector<int>>(value);
} }
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "listFloat") { } else if (type == kVTypeListFloat) {
std::vector<float> attr_value; std::vector<float> attr_value;
auto value_type = value->type(); auto value_type = value->type();
MS_EXCEPTION_IF_NULL(value_type); MS_EXCEPTION_IF_NULL(value_type);
auto value_type_str = value_type->ToString(); auto value_type_str = value_type->ToString();
if (value_type_str == "float") { if (value_type_str == kVTypeFloat) {
auto data = GetValue<float>(value); auto data = GetValue<float>(value);
attr_value.push_back(data); attr_value.push_back(data);
} else { } else {
attr_value = GetValue<std::vector<float>>(value); attr_value = GetValue<std::vector<float>>(value);
} }
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else if (type == "listListInt") { } else if (type == kVTypeListListInt) {
auto attr_value = GetValue<std::vector<std::vector<int>>>(value); auto attr_value = GetValue<std::vector<std::vector<int>>>(value);
(*attr_obj)["value"] = attr_value; (*attr_obj)[kJValue] = attr_value;
} else { } else {
MS_LOG(EXCEPTION) << "type: " << type << "not support"; MS_LOG(EXCEPTION) << "type: " << type << "not support";
} }
@ -503,35 +537,35 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
} }
input_size_list->clear(); input_size_list->clear();
output_size_list->clear(); output_size_list->clear();
for (size_t i = 0; i < kernel_json["op_info"]["inputs"].size(); i++) { for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) {
for (size_t m = 0; m < kernel_json["op_info"]["inputs"][i].size(); m++) { for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) {
size_t size_i = 1; size_t size_i = 1;
if (kernel_json["op_info"]["inputs"][i][m]["valid"] == false) { if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) {
std::string input_name = kernel_json["op_info"]["inputs"][i][m]["name"]; std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName];
MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false."; MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
continue; continue;
} }
for (const auto &j : kernel_json["op_info"]["inputs"][i][m]["shape"]) { for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) {
size_i *= static_cast<size_t>(j); size_i *= static_cast<size_t>(j);
} }
std::string dtype = kernel_json["op_info"]["inputs"][i][m]["dtype"]; std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype); size_t nbyte = tbe::GetDtypeNbyte(dtype);
size_i *= nbyte; size_i *= nbyte;
input_size_list->push_back(size_i); input_size_list->push_back(size_i);
} }
} }
for (size_t i = 0; i < kernel_json["op_info"]["outputs"].size(); i++) { for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) {
for (size_t m = 0; m < kernel_json["op_info"]["outputs"][i].size(); m++) { for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) {
size_t size_i = 1; size_t size_i = 1;
if (kernel_json["op_info"]["outputs"][i][m]["valid"] == false) { if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) {
std::string output_name = kernel_json["op_info"]["outputs"][i][m]["name"]; std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName];
MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false."; MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
continue; continue;
} }
for (const auto &j : kernel_json["op_info"]["outputs"][i][m]["shape"]) { for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) {
size_i *= static_cast<size_t>(j); size_i *= static_cast<size_t>(j);
} }
std::string dtype = kernel_json["op_info"]["outputs"][i][m]["dtype"]; std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype); size_t nbyte = tbe::GetDtypeNbyte(dtype);
size_i *= nbyte; size_i *= nbyte;
output_size_list->push_back(size_i); output_size_list->push_back(size_i);
@ -540,9 +574,9 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
return true; return true;
} }
bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &input_nodes, bool TbeKernelBuild::GenFusionScopeJson(const std::vector<mindspore::AnfNodePtr> &input_nodes,
const vector<mindspore::AnfNodePtr> &compute_nodes, nlohmann::json *fusion_str, const std::vector<mindspore::AnfNodePtr> &compute_nodes,
std::string *fusion_kernel) { nlohmann::json *fusion_str, std::string *fusion_kernel) {
MS_EXCEPTION_IF_NULL(fusion_str); MS_EXCEPTION_IF_NULL(fusion_str);
MS_EXCEPTION_IF_NULL(fusion_kernel); MS_EXCEPTION_IF_NULL(fusion_kernel);
// get input layer info // get input layer info
@ -552,7 +586,7 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp
return false; return false;
} }
// gen fusion scopre_op jsom // gen fusion scopre_op jsom
vector<nlohmann::json> compute_list; std::vector<nlohmann::json> compute_list;
(*fusion_kernel) = kFusionKernelNamePrfix; (*fusion_kernel) = kFusionKernelNamePrfix;
// index: fusion build option input record, next one from 0 // index: fusion build option input record, next one from 0
static size_t index = 0; static size_t index = 0;
@ -565,7 +599,7 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp
} }
index = 0; index = 0;
// gen data input json // gen data input json
vector<nlohmann::json> data_list; std::vector<nlohmann::json> data_list;
for (const auto &layer : input_layers) { for (const auto &layer : input_layers) {
for (const auto &data_input : layer) { for (const auto &data_input : layer) {
nlohmann::json data_str; nlohmann::json data_str;
@ -588,51 +622,51 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
if (node_out_idx > 0) { if (node_out_idx > 0) {
output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
} }
(*output_desc)["name"] = NormalizeFullScopeName(output_desc_name); (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
(*output_desc)["data_type"] = tbe::TypeIdToString(type_id); (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
if (ori_shape.empty()) { if (ori_shape.empty()) {
ori_shape.emplace_back(1); ori_shape.emplace_back(1);
} }
(*output_desc)["ori_shape"] = ori_shape; (*output_desc)[kJOriShape] = ori_shape;
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx); auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx);
if (shape.empty()) { if (shape.empty()) {
shape.emplace_back(1); shape.emplace_back(1);
} }
(*output_desc)["shape"] = shape; (*output_desc)[kJShape] = shape;
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx); auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
if (format == kOpFormat_DEFAULT) { if (format == kOpFormat_DEFAULT) {
format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND; format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
} }
(*output_desc)["format"] = format; (*output_desc)[kJFormat] = format;
(*output_desc)["ori_format"] = kOpFormat_NCHW; (*output_desc)[kJOriFormat] = kOpFormat_NCHW;
(*output_desc)["output_index"] = desc_output_idx; (*output_desc)[kJOutputIndex] = desc_output_idx;
if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) { if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
std::vector<size_t> spec_shape = {}; std::vector<size_t> spec_shape = {};
spec_shape.emplace_back(shape[0]); spec_shape.emplace_back(shape[0]);
spec_shape.emplace_back(shape[1]); spec_shape.emplace_back(shape[1]);
spec_shape.emplace_back(shape[2] * shape[3]); spec_shape.emplace_back(shape[2] * shape[3]);
spec_shape.emplace_back(shape[4]); spec_shape.emplace_back(shape[4]);
(*output_desc)["shape"] = spec_shape; (*output_desc)[kJShape] = spec_shape;
} else if (fusion_data_type == kFusionReLUGradV2) { } else if (fusion_data_type == kFusionReLUGradV2) {
std::vector<size_t> spec_shape = {}; std::vector<size_t> spec_shape = {};
spec_shape.emplace_back(shape[0]); spec_shape.emplace_back(shape[0]);
spec_shape.emplace_back(shape[1]); spec_shape.emplace_back(shape[1]);
spec_shape.emplace_back(shape[2] * shape[3]); spec_shape.emplace_back(shape[2] * shape[3]);
spec_shape.emplace_back(16); spec_shape.emplace_back(16);
(*output_desc)["shape"] = spec_shape; (*output_desc)[kJShape] = spec_shape;
(*output_desc)["data_type"] = "bool"; (*output_desc)[kJDataType] = kVTypeBool;
} }
} }
void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr<mindspore::AnfNode> &anf_node, size_t index, void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
size_t output_index, nlohmann::json *output_desc) { size_t output_index, nlohmann::json *output_desc) {
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
(*output_desc)["name"] = NormalizeFullScopeName(output_desc_name); (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
(*output_desc)["output_index"] = output_index; (*output_desc)[kJOutputIndex] = output_index;
std::vector<size_t> shape; std::vector<size_t> shape;
(*output_desc)["shape"] = shape; (*output_desc)[kJShape] = shape;
} }
bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
@ -657,6 +691,8 @@ bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &in
const std::vector<mindspore::AnfNodePtr> &compute_nodes, const std::vector<mindspore::AnfNodePtr> &compute_nodes,
std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers, std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
std::map<const AnfNodePtr, FusionDataType> *spec_data_input) { std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
MS_EXCEPTION_IF_NULL(input_layers);
MS_EXCEPTION_IF_NULL(spec_data_input);
auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) {
auto op_name = AnfAlgo::GetCNodeName(it); auto op_name = AnfAlgo::GetCNodeName(it);
return op_name == kConv2DBackpropInputOpName; return op_name == kConv2DBackpropInputOpName;
@ -712,10 +748,10 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
if (!data_input) { if (!data_input) {
MS_LOG(INFO) << "data input is optional node"; MS_LOG(INFO) << "data input is optional node";
auto name = std::string(kOptional) + std::to_string(*index); auto name = std::string(kOptional) + std::to_string(*index);
(*data_str)["name"] = name; (*data_str)[kJName] = name;
nlohmann::json output_desc; nlohmann::json output_desc;
output_desc["name"] = name; output_desc[kJName] = name;
output_desc["shape"] = "NULL"; output_desc[kJShape] = "NULL";
output_desc_list.push_back(output_desc); output_desc_list.push_back(output_desc);
(*index)++; (*index)++;
} else { } else {
@ -727,14 +763,14 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
auto real_node = kernel_idx.first; auto real_node = kernel_idx.first;
size_t real_idx = kernel_idx.second; size_t real_idx = kernel_idx.second;
MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx; MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx;
// "output_desc" // kJOutputDesc
nlohmann::json output_desc; nlohmann::json output_desc;
GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type); GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type);
output_desc_list.push_back(output_desc); output_desc_list.push_back(output_desc);
(*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope()); (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope());
} }
(*data_str)["output_desc"] = output_desc_list; (*data_str)[kJOutputDesc] = output_desc_list;
(*data_str)["type"] = "Data"; (*data_str)[kJtype] = "Data";
return true; return true;
} }
@ -765,6 +801,7 @@ bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) {
} }
size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) { size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) {
MS_EXCEPTION_IF_NULL(cnode);
if (is_dynamic_input) { if (is_dynamic_input) {
return 0; return 0;
} }
@ -779,8 +816,8 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
} }
std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
static std::map<std::string, std::string> buffer_fussion_op_map = {{"DepthwiseConv2dNative", "DepthwiseConv2D"}, static std::map<std::string, std::string> buffer_fussion_op_map = {
{"TensorAdd", "Add"}}; {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
string result = origin_type; string result = origin_type;
auto iter = buffer_fussion_op_map.find(origin_type); auto iter = buffer_fussion_op_map.find(origin_type);
if (iter != buffer_fussion_op_map.end()) { if (iter != buffer_fussion_op_map.end()) {
@ -806,7 +843,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
GenDescJson(real_node, real_idx, real_idx, &input_desc); GenDescJson(real_node, real_idx, real_idx, &input_desc);
if (is_dynamic_input) { if (is_dynamic_input) {
MS_LOG(INFO) << "node has dynamic input."; MS_LOG(INFO) << "node has dynamic input.";
input_desc["dyn_index"] = (i - 1); input_desc[kJDynIndex] = (i - 1);
} }
input_desc_list_tmp.emplace_back(input_desc); input_desc_list_tmp.emplace_back(input_desc);
} }
@ -815,7 +852,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
MS_LOG(INFO) << "node has optional input."; MS_LOG(INFO) << "node has optional input.";
for (size_t i = 0; i < optional_num; ++i) { for (size_t i = 0; i < optional_num; ++i) {
nlohmann::json optional_input_desc; nlohmann::json optional_input_desc;
optional_input_desc["name"] = std::string(kOptional) + std::to_string(*index); optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
(*index)++; (*index)++;
(*layer_iter)->emplace_back(nullptr); (*layer_iter)->emplace_back(nullptr);
input_desc_list_tmp.emplace_back(optional_input_desc); input_desc_list_tmp.emplace_back(optional_input_desc);
@ -841,6 +878,7 @@ std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &o
bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
std::vector<nlohmann::json> *output_desc_list) { std::vector<nlohmann::json> *output_desc_list) {
MS_EXCEPTION_IF_NULL(output_desc_list);
auto output_size = AnfAlgo::GetOutputTensorNum(cnode); auto output_size = AnfAlgo::GetOutputTensorNum(cnode);
if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) {
auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum); auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
@ -883,22 +921,22 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
// gen input desc // gen input desc
std::vector<nlohmann::json> input_desc_list; std::vector<nlohmann::json> input_desc_list;
(void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index);
(*compute_op_str)["input_desc"] = input_desc_list; (*compute_op_str)[kJInputDesc] = input_desc_list;
// gen output desc // gen output desc
std::vector<nlohmann::json> output_desc_list; std::vector<nlohmann::json> output_desc_list;
if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) {
MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope(); MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope();
return false; return false;
} }
(*compute_op_str)["output_desc"] = output_desc_list; (*compute_op_str)[kJOutputDesc] = output_desc_list;
// gen others // gen others
auto origin_type = AnfAlgo::GetCNodeName(cnode); auto origin_type = AnfAlgo::GetCNodeName(cnode);
// replace special op type for buffer fusion op // replace special op type for buffer fusion op
auto type = GetRealOpType(origin_type); auto type = GetRealOpType(origin_type);
(*compute_op_str)["type"] = type; (*compute_op_str)[kJtype] = type;
tbe::TbeAdapter::NormalizeFuncName(&type); tbe::TbeAdapter::NormalizeFuncName(&type);
(*compute_op_str)["func_name"] = type; (*compute_op_str)[kJFuncName] = type;
(*compute_op_str)["name"] = NormalizeFullScopeName(cnode->fullname_with_scope()); (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope());
(void)(*fusion_kernel_name).append("_"); (void)(*fusion_kernel_name).append("_");
(void)(*fusion_kernel_name).append(type); (void)(*fusion_kernel_name).append(type);
return true; return true;
@ -906,16 +944,17 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) { size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
size_t ret = 1; size_t ret = 1;
for (const auto &shape_item : desc["shape"]) { for (const auto &shape_item : desc[kJShape]) {
ret *= static_cast<size_t>(shape_item); ret *= static_cast<size_t>(shape_item);
} }
std::string data_type = desc["data_type"]; std::string data_type = desc[kJDataType];
size_t nbyte = tbe::GetDtypeNbyte(data_type); size_t nbyte = tbe::GetDtypeNbyte(data_type);
ret *= nbyte; ret *= nbyte;
return ret; return ret;
} }
bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vector<mindspore::AnfNodePtr> &output_nodes, bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list,
const std::vector<mindspore::AnfNodePtr> &output_nodes,
std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list) { std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list) {
MS_EXCEPTION_IF_NULL(input_size_list); MS_EXCEPTION_IF_NULL(input_size_list);
MS_EXCEPTION_IF_NULL(output_size_list); MS_EXCEPTION_IF_NULL(output_size_list);
@ -923,15 +962,15 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto
output_size_list->clear(); output_size_list->clear();
for (const auto &op : fusion_op_list) { for (const auto &op : fusion_op_list) {
if (op["type"] == "Data") { if (op[kJtype] == "Data") {
const auto &data_output_desc = op["output_desc"]; const auto &data_output_desc = op[kJOutputDesc];
for (const auto &data_output : data_output_desc) { for (const auto &data_output : data_output_desc) {
if (data_output["shape"] == "NULL") { if (data_output[kJShape] == "NULL") {
break; break;
} }
auto ret = GetIOSizeImpl(data_output); auto ret = GetIOSizeImpl(data_output);
input_size_list->push_back(ret); input_size_list->push_back(ret);
MS_LOG(INFO) << "Fusion info: scope input name " << op["name"] << ", size: " << ret; MS_LOG(INFO) << "Fusion info: scope input name " << op[kJName] << ", size: " << ret;
} }
} }
} }
@ -943,13 +982,13 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto
auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope()); auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope());
MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx; MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx;
for (const auto &op : fusion_op_list) { for (const auto &op : fusion_op_list) {
if (op["name"] == normal_name) { if (op[kJName] == normal_name) {
auto op_output_desces = op["output_desc"]; auto op_output_desces = op[kJOutputDesc];
if (output_node != real_node) { if (output_node != real_node) {
// tuple_get item // tuple_get item
MS_LOG(INFO) << "output is a tuple getitem node"; MS_LOG(INFO) << "output is a tuple getitem node";
auto output_desc = op_output_desces[real_idx]; auto output_desc = op_output_desces[real_idx];
if (output_desc["shape"].empty()) { if (output_desc[kJShape].empty()) {
MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx; MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
return false; return false;
} }
@ -958,7 +997,7 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto
MS_LOG(INFO) << "Fusion info: scope output index " << real_idx << ", size: " << ret; MS_LOG(INFO) << "Fusion info: scope output index " << real_idx << ", size: " << ret;
} else { } else {
for (const auto &output_desc : op_output_desces) { for (const auto &output_desc : op_output_desces) {
if (output_desc["shape"].empty()) { if (output_desc[kJShape].empty()) {
MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output"; MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output";
continue; continue;
} }

View File

@ -47,6 +47,7 @@ bool TbeOpParallelPreBuild(const std::vector<AnfNodePtr> &anf_nodes) {
MS_EXCEPTION_IF_NULL(build_manger); MS_EXCEPTION_IF_NULL(build_manger);
for (const auto &anf_node : anf_nodes) { for (const auto &anf_node : anf_nodes) {
// gen kernel json // gen kernel json
MS_EXCEPTION_IF_NULL(anf_node);
nlohmann::json kernel_json; nlohmann::json kernel_json;
TbeKernelJsonCreator creator(OP_PRE_COMPILE); TbeKernelJsonCreator creator(OP_PRE_COMPILE);
if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) {

View File

@ -224,6 +224,9 @@ constexpr char PACK[] = "Pack";
constexpr char GATHER_ND[] = "GatherNd"; constexpr char GATHER_ND[] = "GatherNd";
constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";
// Parallel don't care // Parallel don't care
constexpr char TUPLE_GETITEM[] = "tuple_getitem"; constexpr char TUPLE_GETITEM[] = "tuple_getitem";

View File

@ -258,6 +258,7 @@ void MemReuseUtil::SetKernelDefMap() {
void MemReuseUtil::SetKernelDefInputs() { void MemReuseUtil::SetKernelDefInputs() {
for (const auto &kernel : graph_->execution_order()) { for (const auto &kernel : graph_->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
auto key = kernel.get(); auto key = kernel.get();
// find kernel_def according to cnode addr // find kernel_def according to cnode addr
auto iter = kernel_map_.find(key); auto iter = kernel_map_.find(key);
@ -366,6 +367,7 @@ void MemReuseUtil::SetGraphOutputRefCount() {
void MemReuseUtil::ResetDynamicUsedRefCount() { void MemReuseUtil::ResetDynamicUsedRefCount() {
for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) {
for (auto &ref_count : iter->second) { for (auto &ref_count : iter->second) {
MS_EXCEPTION_IF_NULL(ref_count);
ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; ref_count->ref_count_dynamic_use_ = ref_count->ref_count_;
} }
} }