forked from mindspore-Ecosystem/mindspore
parse core type from pre-build for tbe kernel compile
This commit is contained in:
parent
a90ee15937
commit
2b1b539b36
|
@ -601,6 +601,28 @@ void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const kernel::Fu
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AnfRuntimeAlgorithm::SetCoreType(const AnfNodePtr &node, const std::string &core_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto builder =
|
||||||
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||||
|
MS_EXCEPTION_IF_NULL(builder);
|
||||||
|
builder->SetCoreType(core_type);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AnfRuntimeAlgorithm::GetCoreType(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||||
|
if (kernel_info == nullptr) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
auto build_info = kernel_info->select_kernel_build_info();
|
||||||
|
if (build_info == nullptr) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return build_info->core_type();
|
||||||
|
}
|
||||||
|
|
||||||
void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
|
void AnfRuntimeAlgorithm::SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto builder =
|
auto builder =
|
||||||
|
|
|
@ -147,6 +147,9 @@ class AnfRuntimeAlgorithm {
|
||||||
static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type);
|
static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type);
|
||||||
static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc);
|
static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc);
|
||||||
static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node);
|
static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node);
|
||||||
|
// core type
|
||||||
|
static void SetCoreType(const AnfNodePtr &node, const std::string &core_type);
|
||||||
|
static std::string GetCoreType(const AnfNodePtr &node);
|
||||||
// set select kernel_build_info
|
// set select kernel_build_info
|
||||||
static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
|
static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
|
||||||
// get select kernel_build_info
|
// get select kernel_build_info
|
||||||
|
|
|
@ -185,6 +185,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_ty
|
||||||
kernel_build_info_->fusion_type_ = fusion_type;
|
kernel_build_info_->fusion_type_ = fusion_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetCoreType(const std::string &core_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||||
|
kernel_build_info_->core_type_ = core_type;
|
||||||
|
}
|
||||||
|
|
||||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||||
kernel_build_info_->output_data_desc_ = data_desc;
|
kernel_build_info_->output_data_desc_ = data_desc;
|
||||||
|
|
|
@ -36,6 +36,7 @@ class KernelBuildInfo {
|
||||||
fusion_type_ = OPAQUE;
|
fusion_type_ = OPAQUE;
|
||||||
processor_ = AICORE;
|
processor_ = AICORE;
|
||||||
op_pattern_ = kCommonPattern;
|
op_pattern_ = kCommonPattern;
|
||||||
|
core_type_ = "";
|
||||||
input_reshape_type_ = {};
|
input_reshape_type_ = {};
|
||||||
output_reshape_type_ = {};
|
output_reshape_type_ = {};
|
||||||
origin_data_format_ = kOpFormat_DEFAULT;
|
origin_data_format_ = kOpFormat_DEFAULT;
|
||||||
|
@ -82,6 +83,8 @@ class KernelBuildInfo {
|
||||||
|
|
||||||
std::vector<std::string> GetAllInputReshapeType() const;
|
std::vector<std::string> GetAllInputReshapeType() const;
|
||||||
|
|
||||||
|
std::string core_type() const { return core_type_; }
|
||||||
|
|
||||||
OpPattern op_pattern() const { return op_pattern_; }
|
OpPattern op_pattern() const { return op_pattern_; }
|
||||||
|
|
||||||
std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
|
std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
|
||||||
|
@ -109,6 +112,7 @@ class KernelBuildInfo {
|
||||||
private:
|
private:
|
||||||
KernelType kernel_type_;
|
KernelType kernel_type_;
|
||||||
std::string origin_data_format_;
|
std::string origin_data_format_;
|
||||||
|
std::string core_type_;
|
||||||
std::vector<std::string> inputs_format_;
|
std::vector<std::string> inputs_format_;
|
||||||
OpPattern op_pattern_;
|
OpPattern op_pattern_;
|
||||||
std::vector<std::string> outputs_format_;
|
std::vector<std::string> outputs_format_;
|
||||||
|
@ -133,6 +137,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
||||||
SetFusionType(kernel_build_info->fusion_type());
|
SetFusionType(kernel_build_info->fusion_type());
|
||||||
SetProcessor(kernel_build_info->processor());
|
SetProcessor(kernel_build_info->processor());
|
||||||
SetOpPattern(kernel_build_info->op_pattern());
|
SetOpPattern(kernel_build_info->op_pattern());
|
||||||
|
SetCoreType(kernel_build_info->core_type());
|
||||||
|
SetOutputDataDesc(kernel_build_info->output_data_desc());
|
||||||
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
|
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
|
||||||
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
|
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
|
||||||
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
|
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
|
||||||
|
@ -166,6 +172,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
||||||
|
|
||||||
void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
|
void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
|
||||||
|
|
||||||
|
void SetCoreType(const std::string &core_type);
|
||||||
|
|
||||||
void SetFusionType(FusionType fusion_type);
|
void SetFusionType(FusionType fusion_type);
|
||||||
// save prebuild result
|
// save prebuild result
|
||||||
void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
|
void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
|
||||||
|
|
|
@ -27,15 +27,17 @@ namespace kernel {
|
||||||
* @brief fuse op and return a callable mod
|
* @brief fuse op and return a callable mod
|
||||||
*/
|
*/
|
||||||
struct FusionScopeInfo {
|
struct FusionScopeInfo {
|
||||||
FusionScopeInfo(int64_t id, std::string f_name, std::vector<AnfNodePtr> in, std::vector<AnfNodePtr> comp,
|
FusionScopeInfo(int64_t id, std::string f_name, std::string core_type, std::vector<AnfNodePtr> in,
|
||||||
std::vector<AnfNodePtr> out)
|
std::vector<AnfNodePtr> comp, std::vector<AnfNodePtr> out)
|
||||||
: scope_id(id),
|
: scope_id(id),
|
||||||
full_name(f_name),
|
full_name(f_name),
|
||||||
|
core_type(core_type),
|
||||||
input_nodes(std::move(in)),
|
input_nodes(std::move(in)),
|
||||||
compute_nodes(std::move(comp)),
|
compute_nodes(std::move(comp)),
|
||||||
output_nodes(std::move(out)) {}
|
output_nodes(std::move(out)) {}
|
||||||
int64_t scope_id{};
|
int64_t scope_id{};
|
||||||
std::string full_name{};
|
std::string full_name{};
|
||||||
|
std::string core_type{};
|
||||||
std::vector<AnfNodePtr> input_nodes;
|
std::vector<AnfNodePtr> input_nodes;
|
||||||
std::vector<AnfNodePtr> compute_nodes;
|
std::vector<AnfNodePtr> compute_nodes;
|
||||||
std::vector<AnfNodePtr> output_nodes;
|
std::vector<AnfNodePtr> output_nodes;
|
||||||
|
|
|
@ -171,7 +171,10 @@ void AiCoreDynamicKernel::ComputeTiling() {
|
||||||
tiling::OpTilingCalculateAdapter converter;
|
tiling::OpTilingCalculateAdapter converter;
|
||||||
ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>("default");
|
ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>("default");
|
||||||
auto ge_node = converter.AnfNodeToGeNodeAdapter(cnode, &ge_graph, depend_tensor_map_, op_compile_info_);
|
auto ge_node = converter.AnfNodeToGeNodeAdapter(cnode, &ge_graph, depend_tensor_map_, op_compile_info_);
|
||||||
(void)optiling::OpParaCalculateV2(ge_node, op_run_info_v2);
|
auto ret = optiling::OpParaCalculateV2(ge_node, op_run_info_v2);
|
||||||
|
if (ret != ::ge::GRAPH_SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << "Compute tiling failed!";
|
||||||
|
}
|
||||||
|
|
||||||
block_dim_ = op_run_info_v2.GetBlockDim();
|
block_dim_ = op_run_info_v2.GetBlockDim();
|
||||||
op_run_info_v2.GetAllWorkspaces(workspaces_size_);
|
op_run_info_v2.GetAllWorkspaces(workspaces_size_);
|
||||||
|
|
|
@ -123,7 +123,10 @@ void DynamicTbeKernelMod::InitOp() {
|
||||||
device::tiling::OpTilingCalculateAdapter converter;
|
device::tiling::OpTilingCalculateAdapter converter;
|
||||||
::ge::ComputeGraphPtr ge_graph = std::make_shared<::ge::ComputeGraph>("default");
|
::ge::ComputeGraphPtr ge_graph = std::make_shared<::ge::ComputeGraph>("default");
|
||||||
auto ge_node = converter.AnfNodeToGeNodeAdapter(cnode, &ge_graph, depend_tensor_map_, op_compile_info_);
|
auto ge_node = converter.AnfNodeToGeNodeAdapter(cnode, &ge_graph, depend_tensor_map_, op_compile_info_);
|
||||||
(void)optiling::OpParaCalculateV2(ge_node, op_run_info_v2);
|
auto ret = optiling::OpParaCalculateV2(ge_node, op_run_info_v2);
|
||||||
|
if (ret != ::ge::GRAPH_SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << "Compute tiling failed!";
|
||||||
|
}
|
||||||
|
|
||||||
block_dim_ = op_run_info_v2.GetBlockDim();
|
block_dim_ = op_run_info_v2.GetBlockDim();
|
||||||
std::vector<int64_t> workspace_size_list;
|
std::vector<int64_t> workspace_size_list;
|
||||||
|
|
|
@ -33,9 +33,9 @@ namespace mindspore::kernel {
|
||||||
using mindspore::kernel::tbe::TbeAdapter;
|
using mindspore::kernel::tbe::TbeAdapter;
|
||||||
bool FusionBuildTbeJsonCreator::GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) {
|
bool FusionBuildTbeJsonCreator::GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) {
|
||||||
MS_EXCEPTION_IF_NULL(fusion_json);
|
MS_EXCEPTION_IF_NULL(fusion_json);
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Start Generate Fusion Json, Fusion Node: " << fusion_scope_info.full_name;
|
MS_LOG(DEBUG) << "Start Generate Fusion Json, Fusion Node: " << fusion_scope_info.full_name;
|
||||||
nlohmann::json soc_info_json = kernel::tbe::TbeUtils::GenSocInfo();
|
nlohmann::json soc_info_json = kernel::tbe::TbeUtils::GenSocInfo();
|
||||||
|
soc_info_json[kJCoreType] = fusion_scope_info.core_type;
|
||||||
(*fusion_json)[kJSocInfo] = soc_info_json;
|
(*fusion_json)[kJSocInfo] = soc_info_json;
|
||||||
|
|
||||||
std::vector<nlohmann::json> op_list_json;
|
std::vector<nlohmann::json> op_list_json;
|
||||||
|
@ -44,7 +44,6 @@ bool FusionBuildTbeJsonCreator::GenJson(const FusionScopeInfo &fusion_scope_info
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
(*fusion_json)[kJOpList] = op_list_json;
|
(*fusion_json)[kJOpList] = op_list_json;
|
||||||
|
|
||||||
GenFusionOpName(fusion_json, kJFusionKernelNamePrefix);
|
GenFusionOpName(fusion_json, kJFusionKernelNamePrefix);
|
||||||
AddOpNameForComputeNode(fusion_json);
|
AddOpNameForComputeNode(fusion_json);
|
||||||
(*fusion_json)[kJL1Size] = -1;
|
(*fusion_json)[kJL1Size] = -1;
|
||||||
|
|
|
@ -42,6 +42,8 @@ bool SingleTbeJsonCreator::GenJson(const AnfNodePtr &anf_node, nlohmann::json *k
|
||||||
MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate op_list json failed";
|
MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate op_list json failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
auto core_type = AnfAlgo::GetCoreType(anf_node);
|
||||||
|
soc_info_json[kJCoreType] = core_type;
|
||||||
(*kernel_json)[kJSocInfo] = soc_info_json;
|
(*kernel_json)[kJSocInfo] = soc_info_json;
|
||||||
(*kernel_json)[kJOpList] = op_list;
|
(*kernel_json)[kJOpList] = op_list;
|
||||||
GenFusionOpName(kernel_json);
|
GenFusionOpName(kernel_json);
|
||||||
|
|
|
@ -83,6 +83,7 @@ constexpr auto kJPyModulePath = "py_module_path";
|
||||||
constexpr auto kJAttrs = "attrs";
|
constexpr auto kJAttrs = "attrs";
|
||||||
constexpr auto kJAttrDesc = "attr_desc";
|
constexpr auto kJAttrDesc = "attr_desc";
|
||||||
constexpr auto kJSocInfo = "SocInfo";
|
constexpr auto kJSocInfo = "SocInfo";
|
||||||
|
constexpr auto kJCoreType = "coreType";
|
||||||
constexpr auto kJFusionOpName = "fusion_op_name";
|
constexpr auto kJFusionOpName = "fusion_op_name";
|
||||||
constexpr auto kJGraphID = "graph_id";
|
constexpr auto kJGraphID = "graph_id";
|
||||||
constexpr auto kJType = "type";
|
constexpr auto kJType = "type";
|
||||||
|
|
|
@ -400,11 +400,13 @@ void TbeKernelCompileManager::SavePreBuildResult(int32_t task_id, const std::str
|
||||||
auto op_pattern = GetJsonValue<std::string>(result, "op_pattern");
|
auto op_pattern = GetJsonValue<std::string>(result, "op_pattern");
|
||||||
auto fusion_type = kernel::GetFusionTypeByName(op_pattern);
|
auto fusion_type = kernel::GetFusionTypeByName(op_pattern);
|
||||||
auto output_data_desc = GetJsonValue<nlohmann::json>(result, "op_params");
|
auto output_data_desc = GetJsonValue<nlohmann::json>(result, "op_params");
|
||||||
|
auto core_type = GetJsonValue<nlohmann::json>(result, "core_type");
|
||||||
auto json_name = task_iter->second.json_name;
|
auto json_name = task_iter->second.json_name;
|
||||||
// save pre build result
|
// save pre build result
|
||||||
struct PreBuildResult pre_res;
|
struct PreBuildResult pre_res;
|
||||||
pre_res.json_name = json_name;
|
pre_res.json_name = json_name;
|
||||||
pre_res.fusion_type = fusion_type;
|
pre_res.fusion_type = fusion_type;
|
||||||
|
pre_res.core_type = core_type;
|
||||||
pre_res.output_data_desc = output_data_desc;
|
pre_res.output_data_desc = output_data_desc;
|
||||||
prebuild_res_map_[json_name] = pre_res;
|
prebuild_res_map_[json_name] = pre_res;
|
||||||
}
|
}
|
||||||
|
@ -584,6 +586,8 @@ void TbeKernelCompileManager::UpdateFusionTypeAndOutputDataDesc(const std::vecto
|
||||||
auto pre_res = prebuild_res_map_[kernel_name];
|
auto pre_res = prebuild_res_map_[kernel_name];
|
||||||
auto fusion_type = pre_res.fusion_type;
|
auto fusion_type = pre_res.fusion_type;
|
||||||
auto output_data_desc = pre_res.output_data_desc;
|
auto output_data_desc = pre_res.output_data_desc;
|
||||||
|
auto core_type = pre_res.core_type;
|
||||||
|
AnfAlgo::SetCoreType(node, core_type);
|
||||||
AnfAlgo::SetFusionType(node, fusion_type);
|
AnfAlgo::SetFusionType(node, fusion_type);
|
||||||
AnfAlgo::SetOutputDataDesc(node, {output_data_desc});
|
AnfAlgo::SetOutputDataDesc(node, {output_data_desc});
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,6 +50,7 @@ struct TaskInfo {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PreBuildResult {
|
struct PreBuildResult {
|
||||||
|
std::string core_type;
|
||||||
std::string json_name;
|
std::string json_name;
|
||||||
kernel::FusionType fusion_type;
|
kernel::FusionType fusion_type;
|
||||||
nlohmann::json output_data_desc;
|
nlohmann::json output_data_desc;
|
||||||
|
|
|
@ -49,6 +49,7 @@ using FusedNodeRecord = std::vector<mindspore::HashSet<AnfNodePtr>>;
|
||||||
|
|
||||||
struct BufferFusionInfo_t {
|
struct BufferFusionInfo_t {
|
||||||
std::string full_name;
|
std::string full_name;
|
||||||
|
std::string core_type;
|
||||||
std::vector<AnfNodePtr> anf_nodes;
|
std::vector<AnfNodePtr> anf_nodes;
|
||||||
std::vector<AnfNodePtr> inputs_list;
|
std::vector<AnfNodePtr> inputs_list;
|
||||||
std::vector<AnfNodePtr> outputs_list;
|
std::vector<AnfNodePtr> outputs_list;
|
||||||
|
|
|
@ -305,6 +305,8 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
||||||
fusion_info.all_outputs_from_last_node = true;
|
fusion_info.all_outputs_from_last_node = true;
|
||||||
for (size_t node_idx = 0; node_idx < fusion_info.anf_nodes.size(); ++node_idx) {
|
for (size_t node_idx = 0; node_idx < fusion_info.anf_nodes.size(); ++node_idx) {
|
||||||
const auto &node = fusion_info.anf_nodes[node_idx];
|
const auto &node = fusion_info.anf_nodes[node_idx];
|
||||||
|
auto core_type = AnfAlgo::GetCoreType(node);
|
||||||
|
fusion_info.core_type = core_type;
|
||||||
size_t old_output_num = fusion_info.outputs_list.size();
|
size_t old_output_num = fusion_info.outputs_list.size();
|
||||||
if (common::AnfAlgo::GetOutputTensorNum(node) == 1) {
|
if (common::AnfAlgo::GetOutputTensorNum(node) == 1) {
|
||||||
auto use_nodes = manager->node_users()[node];
|
auto use_nodes = manager->node_users()[node];
|
||||||
|
@ -471,7 +473,7 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
|
||||||
std::transform(buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
|
std::transform(buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
|
||||||
[](const auto &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
|
[](const auto &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
|
||||||
return mindspore::kernel::FusionScopeInfo(
|
return mindspore::kernel::FusionScopeInfo(
|
||||||
buffer_fusion_info.first, buffer_fusion_info.second.full_name,
|
buffer_fusion_info.first, buffer_fusion_info.second.full_name, buffer_fusion_info.second.core_type,
|
||||||
buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.anf_nodes,
|
buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.anf_nodes,
|
||||||
buffer_fusion_info.second.outputs_list);
|
buffer_fusion_info.second.outputs_list);
|
||||||
});
|
});
|
||||||
|
|
|
@ -114,7 +114,7 @@ RangePair PaddingRangeTo5D(const RangePair &ori_range) {
|
||||||
dst_range[W_ncdhw] = ori_range[H_ncdhw];
|
dst_range[W_ncdhw] = ori_range[H_ncdhw];
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << ori_range.size();
|
MS_LOG(EXCEPTION) << "Unexpected shape size: " << ori_range.size();
|
||||||
}
|
}
|
||||||
return dst_range;
|
return dst_range;
|
||||||
}
|
}
|
||||||
|
|
|
@ -355,7 +355,7 @@ std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
|
||||||
shape_5d[W_ncdhw] = shape[H_ncdhw];
|
shape_5d[W_ncdhw] = shape[H_ncdhw];
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
MS_LOG(EXCEPTION) << "Unexpected shape :" << shape;
|
||||||
}
|
}
|
||||||
return shape_5d;
|
return shape_5d;
|
||||||
}
|
}
|
||||||
|
@ -385,7 +385,7 @@ std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
|
||||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
MS_LOG(EXCEPTION) << "Unexpected shape : " << shape;
|
||||||
}
|
}
|
||||||
return shape_4d;
|
return shape_4d;
|
||||||
}
|
}
|
||||||
|
|
|
@ -393,6 +393,7 @@ class TbeJobManager:
|
||||||
pre_compile_result = dict()
|
pre_compile_result = dict()
|
||||||
pre_compile_result["op_pattern"] = target_job.result
|
pre_compile_result["op_pattern"] = target_job.result
|
||||||
pre_compile_result["op_params"] = op_params
|
pre_compile_result["op_params"] = op_params
|
||||||
|
pre_compile_result["core_type"] = new_job["core_type"]
|
||||||
target_job.result = json.dumps(pre_compile_result)
|
target_job.result = json.dumps(pre_compile_result)
|
||||||
target_job.info("Query result:{}".format(new_job["result"]))
|
target_job.info("Query result:{}".format(new_job["result"]))
|
||||||
if new_job["status_code"] == 0:
|
if new_job["status_code"] == 0:
|
||||||
|
|
|
@ -304,7 +304,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {});
|
FusionScopeInfo fusion_scope_info(0, full_name, "", input_nodes, compute_nodes, {});
|
||||||
nlohmann::json fusion_json;
|
nlohmann::json fusion_json;
|
||||||
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
||||||
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
||||||
|
@ -364,7 +364,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {});
|
FusionScopeInfo fusion_scope_info(0, full_name, "", input_nodes, compute_nodes, {});
|
||||||
nlohmann::json fusion_json;
|
nlohmann::json fusion_json;
|
||||||
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
||||||
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
||||||
|
|
Loading…
Reference in New Issue