diff --git a/example/bert_clue/run_pretrain.py b/example/bert_clue/run_pretrain.py
index 2209176d6b5..6b8127ddaca 100644
--- a/example/bert_clue/run_pretrain.py
+++ b/example/bert_clue/run_pretrain.py
@@ -25,6 +25,7 @@ from mindspore.train.model import Model
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
 from dataset import create_bert_dataset
diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
index 655e1dcacd0..a314668c950 100644
--- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
@@ -40,6 +40,7 @@ enum MatchCountPriority : int {
   MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
   MATCH_FORMAT_COUNT,
   MATCH_SPECIAL_FORMAT_COUNT,
+  MATCH_DEFAULT_FORMAT_COUNT,
   MATCH_OUTPUT_DTYPE_COUNT,
   MATCH_COUNT_PRIORITY_END
 };
@@ -73,7 +74,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
     auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
     if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
         kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
-      priority_matched_format = !is_init ? priority_matched_format : pre_output_format;
+      priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
       is_init = true;
     }
     // feature map has two or more special format;
@@ -83,7 +84,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
     auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
     need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
   }
-  if (need_change_nd) {
+  if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
     priority_matched_format = kOpFormat_DEFAULT;
   }
   AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
@@ -134,6 +135,9 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
     if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
       (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
     }
+    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
+      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
+    }
   }
 
   for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
@@ -410,10 +414,10 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
   if (kernel_info_list.empty()) {
     return nullptr;
   }
-  std::vector<int> most_match_counts = {-1, -1, -1, -1};
+  std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
   size_t selected_index = 0;
   for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
-    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
+    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
     auto kernel_build_info = *(kernel_info_list[info_index]);
     std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
     UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index 220b309200d..eca719347c8 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -89,8 +89,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ops/_op_impl/tbe/tanh.py b/mindspore/ops/_op_impl/tbe/tanh.py
index 3d0b2704a31..c5b1caf1dd6 100644
--- a/mindspore/ops/_op_impl/tbe/tanh.py
+++ b/mindspore/ops/_op_impl/tbe/tanh.py
@@ -29,6 +29,8 @@ tanh_op_info = TBERegOp("Tanh") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
     .get_op_info()
 
 
diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py
index ec9b711cd9f..5eb1f40f87b 100644
--- a/tests/st/networks/models/bert/bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py
@@ -170,8 +170,8 @@ def test_bert_tdt():
 
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248,
-                         12.40294, 12.621653]
+    expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
+                         12.403329, 12.621632]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
 
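For context, the two helpers newly imported in run_pretrain.py (load_checkpoint and load_param_into_net) are normally used together to restore saved weights into a network before training continues. A minimal sketch follows; the nn.Dense placeholder network and the "pretrained.ckpt" path are illustrative assumptions, not taken from run_pretrain.py, which would pass its BertNetworkWithLoss instance instead:

    import mindspore.nn as nn
    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    # Hypothetical stand-in network; the pretraining script uses BertNetworkWithLoss.
    net = nn.Dense(3, 2)

    # load_checkpoint reads a .ckpt file into a name -> Parameter dict;
    # load_param_into_net copies those values into the live network.
    param_dict = load_checkpoint("pretrained.ckpt")  # illustrative path
    load_param_into_net(net, param_dict)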