!5883 support for frac_zn_lstm

Merge pull request !5883 from liubuyu/master
This commit is contained in:
mindspore-ci-bot 2020-09-11 09:30:02 +08:00 committed by Gitee
commit 7b3873559f
12 changed files with 194 additions and 14 deletions

View File

@ -60,6 +60,7 @@
#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@ -286,6 +287,9 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>());
ir_fusion_pm->AddPass(std::make_shared<InsertTranspose>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -142,6 +142,15 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
return node;
}
// Rewrites the inferred output type/shape of `node` for the BasicLSTMCellWeightGrad special case.
//
// Acts only when `op_name` is BasicLSTMCellWeightGrad AND `node` is a Reshape; in that case the node's inferred
// output becomes the data type and the first two dimensions of its input 0. Any other combination is a no-op.
//
// node:    node whose inferred output may be rewritten; must not be null.
// op_name: name of the operator on whose behalf the caller is working.
void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(node);
  // The node name is only queried once op_name has matched.
  if (op_name != kBasicLSTMCellWeightGradOpName || AnfAlgo::GetCNodeName(node) != prim::kPrimReshape->name()) {
    return;
  }
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  // at() rather than operator[]: a shape with fewer than two dimensions raises std::out_of_range instead of
  // silently reading past the end of the vector.
  AnfAlgo::SetOutputInferTypeAndShape({input_type}, {{input_shape.at(0), input_shape.at(1)}}, node.get());
}
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -149,6 +158,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
size_t out_num = AnfAlgo::GetOutputTensorNum(node);
std::string op_name;
if (node->isa<CNode>()) {
op_name = AnfAlgo::GetCNodeName(node);
}
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
if (output_format == kOpFormat_NC1KHKWHWC0) {
@ -159,6 +172,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
ReFreshInferShape(trans_op, op_name);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
}

View File

@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
// Pattern handed to the pattern-matching framework: a head variable constrained by the UnVisited condition,
// followed by a sequence variable standing for the remaining elements.
const BaseRef InsertTranspose::DefinePattern() const {
  std::shared_ptr<Var> unvisited_head = std::make_shared<CondVar>(UnVisited);
  std::shared_ptr<Var> remaining_inputs = std::make_shared<SeqVar>();
  VectorRef pattern({unvisited_head, remaining_inputs});
  return pattern;
}
// Builds the replacement node for a BasicLSTMCell gradient op with a Transpose (perm {1, 0}) spliced in.
//
//  * BasicLSTMCellInputGrad: input 1 of `cnode` is routed through a new Transpose whose inferred shape is the first
//    two dimensions of that input, swapped. `cnode` itself is rewired to read from the Transpose and a copy of it
//    (made through the kernel graph when `func_graph` is one) is returned.
//  * BasicLSTMCellWeightGrad: a MakeTuple over one TupleGetItem per output of `cnode` is returned. Output 0 is
//    additionally passed through a Transpose when its inferred shape has more than one dimension; the Transpose
//    keeps the first two dimensions of that inferred shape, in their original order, as its own inferred shape.
//  * Any other `op_name`: nothing is built and nullptr is returned.
//
// func_graph: graph that receives the new nodes; must not be null.
// cnode:      the gradient node being processed; must not be null.
// op_name:    operator name of `cnode`, used to select the branch.
CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  CNodePtr new_node = nullptr;

  // Input list of the Transpose node; it always starts with the Transpose primitive.
  // The local is deliberately not called `prim`, so it cannot be confused with the `prim::` namespace used below.
  std::vector<AnfNodePtr> transpose_inputs;
  auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
  transpose_inputs.push_back(NewValueNode(transpose_prim));

  if (op_name == kBasicLSTMCellInputGradOpName) {
    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 1);
    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
    // at() rather than operator[]: an input with fewer than two dimensions raises std::out_of_range here, before
    // the graph is touched, instead of reading past the end of the shape vector.
    const std::vector<size_t> dst_shape = {origin_shape.at(1), origin_shape.at(0)};
    transpose_inputs.push_back(AnfAlgo::GetInputNode(cnode, 1));
    CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
    MS_EXCEPTION_IF_NULL(transpose);
    AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {dst_shape}, transpose.get());
    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose);
    // Rewire the original node first so that the copy made below already reads from the Transpose.
    AnfAlgo::SetNodeInput(cnode, transpose, 1);
    if (kernel_graph == nullptr) {
      new_node = std::make_shared<CNode>(*cnode);
    } else {
      new_node = kernel_graph->NewCNode(cnode);
    }
  } else if (op_name == kBasicLSTMCellWeightGradOpName) {
    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
      auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx);
      auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
      // Only output 0 is transposed, and only when it has at least two dimensions (which also makes the two
      // index accesses below safe).
      if (origin_shape.size() > 1 && output_idx == 0) {
        auto dtype = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
        const std::vector<size_t> dst_shape = {origin_shape[0], origin_shape[1]};
        transpose_inputs.push_back(tuple_getitem);
        CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
        MS_EXCEPTION_IF_NULL(transpose);
        AnfAlgo::SetOutputInferTypeAndShape({dtype}, {dst_shape}, transpose.get());
        AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose);
        make_tuple_inputs.push_back(transpose);
      } else {
        make_tuple_inputs.push_back(tuple_getitem);
      }
    }
    new_node = func_graph->NewCNode(make_tuple_inputs);
  }
  return new_node;
}
// Pass entry point for one matched node.
//
// The node is first tagged with the visited attribute. If it is a BasicLSTMCellInputGrad or
// BasicLSTMCellWeightGrad node, the result of Insert() is returned; for every other operator nullptr is returned.
// A null graph, a null node, or a node that is not a CNode raises through MS_EXCEPTION_IF_NULL.
const AnfNodePtr InsertTranspose::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);

  auto lstm_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(lstm_cnode);

  auto name = AnfAlgo::GetCNodeName(lstm_cnode);
  const bool is_lstm_grad = (name == kBasicLSTMCellInputGradOpName) || (name == kBasicLSTMCellWeightGradOpName);
  if (!is_lstm_grad) {
    return nullptr;
  }
  return Insert(func_graph, lstm_cnode, name);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
#include <string>
#include <utility>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// Pattern-process pass registered under the name "insert_transpose_for_basiclstm_op".
// Declares the two PatternProcessPass hooks; their definitions live in the matching source file.
class InsertTranspose : public PatternProcessPass {
 public:
  // multigraph is forwarded unchanged to the PatternProcessPass constructor.
  explicit InsertTranspose(bool multigraph = true)
      : PatternProcessPass("insert_transpose_for_basiclstm_op", multigraph) {
  }

  ~InsertTranspose() override = default;

  // Returns the pattern this pass wants to be invoked on.
  const BaseRef DefinePattern() const override;

  // Handles one matched node and returns its replacement.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H

View File

@ -24,6 +24,7 @@ namespace opt {
// Format pairs that this pass treats as "invalid".
// NOTE(review): presumably each pair is (input format, output format) of a TransData that must be split — confirm
// against the code that queries this set.
const std::set<std::pair<string, string>> invalid_formats_pair = {
  {kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
  {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
  {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
  {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
  {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0},
};
bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
@ -64,12 +65,13 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
AnfNodePtr new_transdata_node = nullptr;
AnfNodePtr new_transpose_node = nullptr;
AnfNodePtr new_replace_node = nullptr;
auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
// trans input_format to hwcn
new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
// trans hwcn to default_format
new_transpose_node =
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
@ -86,7 +88,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
// trans hwcn to output_format
new_transdata_node =
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
new_transdata_node->set_abstract(node->abstract());
new_replace_node = new_transdata_node;
}

View File

@ -56,7 +56,7 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
// Ceiling division: the smallest value q such that q * n2 >= n1, for non-negative n1 and positive n2.
//
// Returns 0 when n2 is 0 (no division is attempted).
// The (n1 + n2 - 1) / n2 form is used on purpose: unlike (n1 - 1) / n2 + 1 it yields 0 for n1 == 0 instead of
// wrapping around when T is an unsigned type.
template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 != 0) {
    return (n1 + n2 - 1) / n2;
  }
  return 0;
}
@ -444,6 +444,17 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
const size_t c0 = 4;
const size_t h = shape.at(kN) / c0;
const size_t i = shape.at(kC) - h;
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
const size_t second = c0 * DivCeil(h, kCubeSize);
device_shape.push_back(first);
device_shape.push_back(second);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
if (shape.size() != kNchwDims) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";

View File

@ -216,6 +216,9 @@ constexpr auto kReduceMinOpName = "ReduceMin";
constexpr auto kReduceMaxOpName = "ReduceMax";
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
// attr key name
constexpr auto kAttrInputNames = "input_names";
@ -344,10 +347,11 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
constexpr auto kOpFormat_NDHWC = "NDHWC";
constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
const std::set<std::string> kOpFormatList = {
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {
kMomentumOpName,
@ -373,9 +377,9 @@ const std::set<std::string> kOptOperatorSet = {
kPullOpName,
};
const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};
const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

View File

@ -30,7 +30,7 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "c", False, "required", "all") \
.input(3, "w", False, "required", "all") \
.input(3, "w", False, "required", "all", reshape_type="CN") \
.input(4, "b", False, "required", "all") \
.input(5, "mask", False, "optional", "all") \
.output(0, "ct", False, "required", "all") \
@ -40,11 +40,11 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
.output(4, "ft", False, "optional", "all") \
.output(5, "ot", False, "optional", "all") \
.output(6, "tanhct", False, "optional", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZNLSTM,
DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZNLSTM,
DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \

View File

@ -25,7 +25,7 @@ basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
.attr("keep_prob", "optional", "float", "all") \
.partial_flag(True) \
.input(0, "dgate", False, "required", "all") \
.input(1, "w", False, "required", "all") \
.input(1, "w", False, "required", "all", reshape_type="NC") \
.input(2, "dropout_mask", False, "optional", "all") \
.output(0, "dxt", False, "required", "all") \
.output(1, "dht", False, "required", "all") \

View File

@ -26,7 +26,7 @@ basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "dgate", False, "required", "all") \
.output(0, "dw", False, "required", "all") \
.output(0, "dw", False, "required", "all", reshape_type="CN") \
.output(1, "db", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
DataType.F32_Default) \

View File

@ -129,6 +129,10 @@ trans_data_op_info = TBERegOp("TransData") \
.dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \
.dtype_format(DataType.F16_HWCN, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_HWCN, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_HWCN, DataType.F16_FracZNLSTM) \
.dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \
.dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \
.dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \
.get_op_info()

View File

@ -619,6 +619,7 @@ class DataType:
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
@ -630,6 +631,7 @@ class DataType:
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")