diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 347d63be39f..a7c8d131fb9 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -35,6 +35,7 @@ enum MatchCountPriority : int { MATCH_COUNT_PRIORITY_BEGIN = 0, MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, MATCH_FORMAT_COUNT, + MATCH_SPECIAL_FORMAT_COUNT, MATCH_5D_FORMAT_COUNT, MATCH_OUTPUT_DTYPE_COUNT, MATCH_COUNT_PRIORITY_END @@ -81,6 +82,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return true; }; + if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" || + AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") { + if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) { + return true; + } + } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); @@ -154,7 +161,7 @@ bool PriorityChooseItem(const std::vector &cur_item, std::vector *best return false; } } - return false; + return true; } void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node, @@ -174,12 +181,11 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons continue; } } - if (input_anf_node->isa()) { - if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) { - continue; - } - } if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { + if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) && + kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) { + (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++; + } (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; } if (kernel_build_info.GetInputDeviceType(input_index) == @@ -203,7 +209,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; } } -} +} // namespace void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index ee1eeaddfce..ea5e969e524 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -195,6 +195,9 @@ const std::set kOptOperatorSet = { kApplyRMSPropOpName, }; +const std::set kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, + kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; + static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { if (access(file_name.c_str(), F_OK) != 0) { MS_LOG(DEBUG) << "File `" << file_name << "` does not exist."; diff --git a/mindspore/ops/_op_impl/tbe/gelu.py b/mindspore/ops/_op_impl/tbe/gelu.py index 80933125470..171d97c0437 100644 --- a/mindspore/ops/_op_impl/tbe/gelu.py +++ b/mindspore/ops/_op_impl/tbe/gelu.py @@ -32,10 +32,10 @@ from mindspore.ops.op_info_register import op_info_register { "index": 0, "dtype": [ - "float16","float","float16","float16","float16","float16","float","float","float","float" + "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float" ], "format": [ - "FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" + "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" ], "name": "x", "need_compile": false, @@ -47,10 +47,10 @@ from mindspore.ops.op_info_register import op_info_register { "index": 0, "dtype": [ - "float16","float","float16","float16","float16","float16","float","float","float","float" + "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float" ], "format": [ - "FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" + "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" ], "name": "y", "need_compile": true, diff --git a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py index 7c50707fbdb..6f3ffc7dadd 100644 --- a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py @@ -153,8 +153,7 @@ def test_bert_tdt(): batch_size = int(os.getenv('BATCH_SIZE', '16')) config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4, - end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False) + optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9) netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) netwithgrads.set_train(True) model = Model(netwithgrads) @@ -178,10 +177,10 @@ def test_bert_tdt(): param.default_input = weight_variable(value.asnumpy().shape) model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False) loss_value = np.array(parallel_callback.loss_list) - expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849, - 10.616718, 10.486609] + expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594, + 12.824818, 12.38842, 12.604046] logger.info("expected loss value output: {}".format(expect_out)) - assert allclose(loss_value, expect_out, 0.001, 0.001) + assert allclose(loss_value, expect_out, 0.00001, 0.00001) if __name__ == '__main__': test_bert_tdt()