diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 2f3c57c665c..04869ff1534 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -195,6 +195,7 @@ void RunOpAscendDataLayout(const std::shared_ptr &kernel_g auto data_layout_pm = std::make_shared("pynative_transop_pm"); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); @@ -338,7 +339,9 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index f2740095377..67cd3dd4aca 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -529,13 +529,6 @@ bool TransDataType(const TypeIdArgs &args, void *result) { } bool TransFormat(const FormatArgs &args, void *result) { - using FormatTransfer = std::function; - const std::map format_trans_map{ - {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, - {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, - {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, - {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}}; - MS_LOG(DEBUG) << "Start trans format."; if (abstract::TypeIdSize(args.src_data_type) < 1) { MS_LOG(ERROR) << "Invalid datatype.."; @@ -544,15 +537,14 @@ bool TransFormat(const FormatArgs &args, void *result) { if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { return NchwTo4D(args, result); } - auto iter = format_trans_map.find(args.device_format); - if (iter == format_trans_map.end()) { + auto iter = kTransFormatMapOfHostToDevice.find(args.device_format); + if (iter == kTransFormatMapOfHostToDevice.end()) { MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; } return iter->second(args, result); } bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { - using FormatTransfer = std::function; const std::map format_trans_map{ {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 3275d9e364a..153b773f268 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -76,6 +76,13 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result); +using FormatTransfer = std::function; +const std::map kTransFormatMapOfHostToDevice{ + {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, + {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, + {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, + {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}}; + } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 18cf9177d1c..a134a6165c3 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -29,6 +29,7 @@ #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" #include "backend/session/anf_runtime_algorithm.h" +#include "common/trans.h" #include "debug/anf_ir_dump.h" #include "frontend/operator/ops.h" #include "utils/ms_context.h" @@ -382,14 +383,15 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { continue; } auto builder = std::make_shared(); - std::vector output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; + auto refresh_format = selected_kernel_info->GetInputFormat(input_index); + std::vector output_format = {refresh_format}; + // if not find in host convert format map means the host has not registered the convert function of this format + if (trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() && + refresh_format != kOpFormat_DEFAULT) { + output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; + } if (IsValueNode(input_kernel_node) && AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { - if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || - selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || - selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { - output_format = {selected_kernel_info->GetInputFormat(input_index)}; - } builder->SetOutputsFormat(output_format); std::vector output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; builder->SetOutputsDeviceType(output_type); @@ -397,11 +399,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { continue; } if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { - if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || - selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || - selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { - output_format = {selected_kernel_info->GetInputFormat(input_index)}; - } builder->SetOutputsFormat(output_format); std::vector output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; builder->SetOutputsDeviceType(output_type);