From 23a298ca811f6d631f182cc33dbbaf2ddb9f450e Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Tue, 8 Sep 2020 08:51:29 +0800
Subject: [PATCH] support new format frac_zn_lstm

---
 .../ascend/ascend_backend_optimization.cc     |   4 +
 .../backend/optimizer/ascend/ascend_helper.cc |  14 +++
 .../insert_transpose_for_basiclstm_op.cc      | 100 ++++++++++++++++++
 .../insert_transpose_for_basiclstm_op.h       |  39 +++++++
 .../ascend/ir_fission/transdata_split.cc      |   6 +-
 mindspore/ccsrc/common/trans.cc               |  13 ++-
 mindspore/ccsrc/utils/utils.h                 |  16 +--
 mindspore/ops/_op_impl/tbe/basic_lstm_cell.py |   6 +-
 .../tbe/basic_lstm_cell_input_grad.py         |   2 +-
 .../tbe/basic_lstm_cell_weight_grad.py        |   2 +-
 mindspore/ops/_op_impl/tbe/trans_data.py      |   4 +
 mindspore/ops/op_info_register.py             |   2 +
 12 files changed, 194 insertions(+), 14 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 3c4074c89e6..5caa575d138 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -58,6 +58,7 @@
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
 #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@@ -284,6 +285,9 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
   optimizer->AddPassManager(ir_fusion_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index b4eb70c7269..5a0a42ab74e 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -142,6 +142,15 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
   return node;
 }
 
+void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) {
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+    auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+    AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
+  }
+}
+
 AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -149,6 +158,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   auto kernel_graph = func_graph->cast();
   size_t out_num = AnfAlgo::GetOutputTensorNum(node);
+  std::string op_name;
+  if (node->isa()) {
+    op_name = AnfAlgo::GetCNodeName(node);
+  }
   for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
@@ -159,6 +172,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
     if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
+      ReFreshInferShape(trans_op, op_name);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
       }
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.cc
new file mode 100644
index 00000000000..365bacd1fff
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
+#include
+#include
+#include "utils/utils.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef InsertTranspose::DefinePattern() const {
+  std::shared_ptr V = std::make_shared(UnVisited);
+  std::shared_ptr Xs = std::make_shared();
+  return VectorRef({V, Xs});
+}
+
+CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::string &op_name) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_graph = func_graph->cast();
+  CNodePtr new_node = nullptr;
+
+  std::vector transpose_inputs;
+  auto prim = std::make_shared(prim::kPrimTranspose->name());
+  transpose_inputs.push_back(NewValueNode(prim));
+
+  if (op_name == kBasicLSTMCellInputGradOpName) {
+    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 1);
+    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
+    auto dst_shape = {origin_shape[1], origin_shape[0]};
+    transpose_inputs.push_back(AnfAlgo::GetInputNode(cnode, 1));
+    CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
+    MS_EXCEPTION_IF_NULL(transpose);
+    AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {dst_shape}, transpose.get());
+    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{1, 0}), transpose);
+    AnfAlgo::SetNodeInput(cnode, transpose, 1);
+    if (kernel_graph == nullptr) {
+      new_node = std::make_shared(*cnode);
+    } else {
+      new_node = kernel_graph->NewCNode(cnode);
+    }
+
+  } else if (op_name == kBasicLSTMCellWeightGradOpName) {
+    std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+    size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
+    for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
+      auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx);
+      auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
+      if (origin_shape.size() > 1 && output_idx == 0) {
+        auto dtype = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
+        auto dst_shape = {origin_shape[0], origin_shape[1]};
+        transpose_inputs.push_back(tuple_getitem);
+        CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
+        MS_EXCEPTION_IF_NULL(transpose);
+        AnfAlgo::SetOutputInferTypeAndShape({dtype}, {dst_shape}, transpose.get());
+        AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{1, 0}), transpose);
+        make_tuple_inputs.push_back(transpose);
+      } else {
+        make_tuple_inputs.push_back(tuple_getitem);
+      }
+    }
+    new_node = func_graph->NewCNode(make_tuple_inputs);
+  }
+  return new_node;
+}
+
+const AnfNodePtr InsertTranspose::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                          const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  auto cnode = node->cast();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  CNodePtr new_node = nullptr;
+  if (op_name == kBasicLSTMCellInputGradOpName || op_name == kBasicLSTMCellWeightGradOpName) {
+    new_node = Insert(func_graph, cnode, op_name);
+  }
+  return new_node;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h
new file mode 100644
index 00000000000..5245aa1778b
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
+#include
+#include
+#include
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertTranspose : public PatternProcessPass {
+ public:
+  explicit InsertTranspose(bool multigraph = true)
+      : PatternProcessPass("insert_transpose_for_basiclstm_op", multigraph) {}
+  ~InsertTranspose() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc
index a25ebb8cfdf..e07ccc1056f 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc
@@ -24,6 +24,7 @@ namespace opt {
 const std::set> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
                                         {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
                                         {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
+                                        {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
                                         {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
 
 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
@@ -64,12 +65,13 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
   AnfNodePtr new_transdata_node = nullptr;
   AnfNodePtr new_transpose_node = nullptr;
   AnfNodePtr new_replace_node = nullptr;
+  auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
   // if output_format=default transdata need split transdata->transpose else transpose->transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
     new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, false,
                                         prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
     // trans hwcn to default_format
     new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false,
                                         prim::kPrimTranspose->name());
@@ -86,7 +88,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
     // trans hwcn to output_format
     new_transdata_node = NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false,
                                         prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
     new_transdata_node->set_abstract(node->abstract());
     new_replace_node = new_transdata_node;
   }
diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 12418873dc8..5e739611d7f 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -56,7 +56,7 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
 template
 T DivCeil(T n1, T n2) {
   if (n2 != 0) {
-    return (n1 - 1) / n2 + 1;
+    return (n1 + n2 - 1) / n2;
   }
   return 0;
 }
@@ -444,6 +444,17 @@ std::vector TransShapeToDevice(const std::vector &shape, const s
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
+  } else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
+    const size_t c0 = 4;
+    const size_t h = shape.at(kN) / c0;
+    const size_t i = shape.at(kC) - h;
+    const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
+    const size_t second = c0 * DivCeil(h, kCubeSize);
+    device_shape.push_back(first);
+    device_shape.push_back(second);
+    device_shape.push_back(kCubeSize);
+    device_shape.push_back(kCubeSize);
+    return device_shape;
   }
   if (shape.size() != kNchwDims) {
     MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index f7d905f649f..506ad4a126c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -196,6 +196,9 @@ constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
 constexpr auto kTensorAddOpName = "TensorAdd";
 constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
 constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
+constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
+constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
+constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
@@ -324,10 +327,11 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
 constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 constexpr auto kOpFormat_NDHWC = "NDHWC";
+constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
 const std::set kOpFormatList = {
-  kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
-  kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
-  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
+  kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
+  kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
+  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
 const std::set kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
 const std::set kOptOperatorSet = {
   kMomentumOpName,
@@ -353,9 +357,9 @@ const std::set kOptOperatorSet = {
   kPullOpName,
 };
 
-const std::set kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                      kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
-                                      kOpFormat_FRACTAL_Z_C04};
+const std::set kHWSpecialFormatSet = {
+  kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
+  kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
 
 const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
index 76ad1e46076..77eb2e94a65 100644
--- a/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell.py
@@ -30,7 +30,7 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
     .input(0, "x", False, "required", "all") \
     .input(1, "h", False, "required", "all") \
     .input(2, "c", False, "required", "all") \
-    .input(3, "w", False, "required", "all") \
+    .input(3, "w", False, "required", "all", reshape_type="CN") \
     .input(4, "b", False, "required", "all") \
     .input(5, "mask", False, "optional", "all") \
     .output(0, "ct", False, "required", "all") \
@@ -40,11 +40,11 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
     .output(4, "ft", False, "optional", "all") \
     .output(5, "ot", False, "optional", "all") \
     .output(6, "tanhct", False, "optional", "all") \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZNLSTM,
                   DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZNLSTM,
                   DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ) \
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
index d976d1143b7..222d0b336d1 100644
--- a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py
@@ -25,7 +25,7 @@ basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
     .attr("keep_prob", "optional", "float", "all") \
     .partial_flag(True) \
     .input(0, "dgate", False, "required", "all") \
-    .input(1, "w", False, "required", "all") \
+    .input(1, "w", False, "required", "all", reshape_type="NC") \
     .input(2, "dropout_mask", False, "optional", "all") \
     .output(0, "dxt", False, "required", "all") \
     .output(1, "dht", False, "required", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
index 83726bc5105..4b9501a7e42 100644
--- a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
+++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py
@@ -26,7 +26,7 @@ basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
     .input(0, "x", False, "required", "all") \
     .input(1, "h", False, "required", "all") \
     .input(2, "dgate", False, "required", "all") \
-    .output(0, "dw", False, "required", "all") \
+    .output(0, "dw", False, "required", "all", reshape_type="CN") \
     .output(1, "db", False, "required", "all") \
     .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
                   DataType.F32_Default) \
diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py
index 91b4bc85f59..666902172c4 100644
--- a/mindspore/ops/_op_impl/tbe/trans_data.py
+++ b/mindspore/ops/_op_impl/tbe/trans_data.py
@@ -129,6 +129,10 @@ trans_data_op_info = TBERegOp("TransData") \
     .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \
     .dtype_format(DataType.F16_HWCN, DataType.F16_FracNZ) \
     .dtype_format(DataType.F32_HWCN, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_HWCN, DataType.F16_FracZNLSTM) \
+    .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \
+    .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \
+    .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \
     .get_op_info()
 
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index 65d6d2cdb8a..0397b3ecab2 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -619,6 +619,7 @@ class DataType:
     F16_NHWC = ("float16", "NHWC")
     F16_HWCN = ("float16", "HWCN")
     F16_NDHWC = ("float16", "NDHWC")
+    F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
 
     F32_None = ("float32", "")
     F32_Default = ("float32", "DefaultFormat")
@@ -630,6 +631,7 @@ class DataType:
     F32_NHWC = ("float32", "NHWC")
     F32_HWCN = ("float32", "HWCN")
     F32_NDHWC = ("float32", "NDHWC")
+    F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
 
     F64_None = ("float64", "")
     F64_Default = ("float64", "DefaultFormat")
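
Reviewer note (not part of the patch): below is a minimal standalone sketch of the shape rule the trans.cc hunk introduces, so the FRAC_ZN_LSTM arithmetic can be checked without tracing TransShapeToDevice. The helper name FracZnLstmDeviceShape, the sample sizes, and the main driver are illustrative assumptions; only the arithmetic itself (c0 = 4 gates, kCubeSize = 16, the two DivCeil terms) is taken from the diff above, which reads host dim 0 as 4 * hidden and host dim 1 as input + hidden.

// frac_zn_lstm_shape_sketch.cc -- illustrative only, not MindSpore source.
#include <cstddef>
#include <iostream>
#include <vector>

namespace {
constexpr size_t kCubeSize = 16;  // cube edge used by the Ascend fractal formats

// Ceiling division as rewritten in trans.cc; unlike (n1 - 1) / n2 + 1 it also
// yields 0 for n1 == 0 instead of underflowing an unsigned operand.
size_t DivCeil(size_t n1, size_t n2) { return n2 != 0 ? (n1 + n2 - 1) / n2 : 0; }

// Mirrors the new kOpFormat_FRACTAL_ZN_LSTM branch of TransShapeToDevice:
// the (input + hidden) rows and the per-gate hidden columns of the LSTM weight
// are each padded up to whole 16 x 16 cubes.
std::vector<size_t> FracZnLstmDeviceShape(const std::vector<size_t> &host_shape) {
  const size_t c0 = 4;                     // four LSTM gates
  const size_t h = host_shape.at(0) / c0;  // hidden size
  const size_t i = host_shape.at(1) - h;   // input size
  const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
  const size_t second = c0 * DivCeil(h, kCubeSize);
  return {first, second, kCubeSize, kCubeSize};
}
}  // namespace

int main() {
  // Example: input_size = 96, hidden_size = 32 gives a host weight shape of {128, 128}.
  const std::vector<size_t> host = {128, 128};
  const auto device = FracZnLstmDeviceShape(host);
  // Prints "8 8 16 16": (96/16 + 32/16) cubes by 4 * (32/16) cubes, each 16 x 16.
  std::cout << device[0] << ' ' << device[1] << ' ' << device[2] << ' ' << device[3] << std::endl;
  return 0;
}

The DivCeil rewrite in the same hunk is behavior-preserving for positive operands; the new form additionally returns 0 for DivCeil(0, n) rather than underflowing the unsigned subtraction, which seems relevant once DivCeil is applied to sizes derived by subtraction as in the new branch.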