!26370 DynamicRNNGrad: support the case where `hidden_size` is not a multiple of 16

Merge pull request !26370 from yuchaojie/ir_fusion4
i-robot 2021-11-24 09:20:30 +00:00 committed by Gitee
commit fa5ea7b3a6
14 changed files with 296 additions and 85 deletions

View File

@ -125,7 +125,7 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
public:
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
SetKernelType(kernel_build_info->kernel_type());
SetFusionType(kernel_build_info->fusion_type());

View File

@ -192,6 +192,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
return make_tuple;
}
} // namespace
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
@ -244,9 +245,9 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
}
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
const AnfNodePtr &trans_node, const std::string &reshape_type, const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_node);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
MS_EXCEPTION_IF_NULL(ori_build_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
MS_EXCEPTION_IF_NULL(builder);
@ -258,8 +259,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
builder->SetOutputsDeviceType({type_id});
builder->SetInputsDeviceType({type_id});
}
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
SetTransNodeAttr(trans_data->cast<CNodePtr>());
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_node.get());
SetTransNodeAttr(trans_node->cast<CNodePtr>());
}
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const AnfNodePtr &orig_node,

View File

@ -100,7 +100,7 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::string &reshape_type = {""},
const AnfNodePtr &trans_node, const std::string &reshape_type = {""},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const AnfNodePtr &orig_node,

View File

@ -96,7 +96,7 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
for (size_t i = 0; i < in_num; ++i) {
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
<< cnode->DebugString() << "]";
<< cnode->DebugString() << "], fullname: " << node->fullname_with_scope();
}
}
return nullptr;

View File

@ -40,13 +40,15 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
auto input_format = AnfAlgo::GetInputFormat(cnode, index);
auto input_node = AnfAlgo::GetInputNode(cnode, index);
// convert the format of node's input node to default
if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) {
if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() &&
(prev_node_out_infer_shape.size() > 1 || prev_input_format == kOpFormat_ND_RNN_BIAS)) {
auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, 0, false);
AnfAlgo::SetNodeInput(cnode, trans_node, index);
has_changed = true;
}
// convert node's output format
if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) {
if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() &&
(prev_node_out_infer_shape.size() > 1 || input_format == kOpFormat_ND_RNN_BIAS)) {
auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, index, true);
AnfAlgo::SetNodeInput(cnode, trans_node, index);
has_changed = true;

View File

@ -15,9 +15,11 @@
*/
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include <vector>
#include <string>
#include <memory>
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "utils/trace_base.h"
#include "utils/tensor_construct_utils.h"
@ -34,9 +36,40 @@ constexpr int64_t kAttrAxis2Value = 2;
constexpr int64_t kAttrNumSplitValue = 2;
constexpr int64_t kAttrSplitDimValue = 2;
constexpr size_t kDimMultiNum = 4;
void SetAttrInputAndHiddenSize(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
int64_t input_size, int64_t hidden_size) {
auto input = dynamic_rnn_grad_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(input);
// set for input
while (input->isa<CNode>()) {
AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), input);
AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), input);
auto input_cnode = input->cast<CNodePtr>();
input = input_cnode->input(kIndex1);
}
if (input->isa<Parameter>()) {
auto param = input->cast<ParameterPtr>();
param->set_input_size(input_size);
param->set_hidden_size(hidden_size);
}
// set for output
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto getitem_index : manager->node_users()[dynamic_rnn_grad_cnode]) {
if (AnfAlgo::CheckPrimitiveType(getitem_index.first, prim::kPrimTupleGetItem)) {
for (auto node_index : manager->node_users()[getitem_index.first]) {
AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), node_index.first);
AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), node_index.first);
}
}
}
}
} // namespace
void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const RNNShapeSpecs &specs,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
@ -45,19 +78,15 @@ void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, co
std::vector<AnfNodePtr> matmul_nodes;
std::vector<AnfNodePtr> split_nodes;
// Get the size of t
auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);
for (size_t i = 0; i < t_size; ++i) {
for (size_t i = 0; i < specs.t_size; ++i) {
// Create basic_lstm_cell_c_state_grad
std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
std::vector<size_t> output0_dims{
origin_input9_shape[kDim0],
kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
std::vector<size_t> output0_dims{specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize};
std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
basic_lstm_cell_c_state_grad.get());
@ -66,30 +95,40 @@ void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, co
// Create matmul
auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
std::vector<AnfNodePtr> matmul_inputs;
if (specs.shape_need_align) {
matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
} else {
matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
}
auto matmul = NewCNode(matmul_inputs, func_graph);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
matmul.get());
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
if (specs.shape_need_align) {
AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), matmul);
AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), matmul);
}
// Create split
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
auto split_v = NewCNode(splitv_input, func_graph);
auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
std::vector<size_t> split_v_output0_shape{IntToSize(1), specs.batch_size, specs.input_size};
std::vector<size_t> split_v_output1_shape{IntToSize(1), specs.batch_size, specs.hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
{split_v_output0_shape, split_v_output1_shape}, split_v.get());
AnfAlgo::SetNodeAttr(kAttrSizeSplits,
MakeValue(std::vector<int64_t>{
SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
MakeValue(std::vector<int64_t>{SizeToLong(specs.input_nz_size * kCubeSize),
SizeToLong(specs.hidden_nz_size * kCubeSize)}),
split_v);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);
if (specs.shape_need_align) {
AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), split_v);
AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ}),
split_v);
}
basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
matmul_nodes.emplace_back(matmul);
@ -117,7 +156,7 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateLSTMSPlitV(const FuncGraphPtr &func_gr
void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const std::vector<std::vector<AnfNodePtr>> &result_nodes,
size_t num_split_x,
size_t num_split_x, bool shape_need_align,
std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const {
auto &basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
auto &matmul_nodes = result_nodes[kIndex1];
@ -166,7 +205,6 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
(void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
(void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
// Create outputs for current basic_lstm_cell_c_state_grad node
@ -176,11 +214,15 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
// Create MatMul
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
std::vector<AnfNodePtr> matmul_inputs;
if (shape_need_align) {
matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
} else {
matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
}
(void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
(void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
auto matmul = NewCNode(matmul_inputs, func_graph);
MS_EXCEPTION_IF_NULL(matmul);
matmul->set_abstract(matmul_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);
@ -188,7 +230,6 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
matmul};
auto split_v = NewCNode(splitv_input, func_graph);
MS_EXCEPTION_IF_NULL(split_v);
split_v->set_abstract(split_nodes[i]->abstract());
AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);
@ -223,9 +264,10 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const RNNShapeSpecs &specs,
std::vector<AnfNodePtr> *outputs) const {
std::vector<std::vector<AnfNodePtr>> result_nodes;
CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, specs, &result_nodes);
auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
@ -290,7 +332,8 @@ AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &fun
// Add edges
std::vector<std::vector<AnfNodePtr>> loop_node_outputs;
CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, &loop_node_outputs);
CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, specs.shape_need_align,
&loop_node_outputs);
auto &pre_basic_lstm_cell_c_state_grad_outputs = loop_node_outputs[kIndex0];
auto &pre_split_outputs = loop_node_outputs[kIndex1];
auto &lstm_x_concat_input = loop_node_outputs[kIndex2];
@ -306,10 +349,8 @@ AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &fun
// Create lstm_gage_concat
auto lstm_gage_concat = NewCNode(lstm_gage_concat_input, func_graph);
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
AnfAlgo::SetOutputInferTypeAndShape(
{kNumberTypeFloat16},
{{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
{kNumberTypeFloat16}, {{specs.t_size, specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize}},
lstm_gage_concat.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
@ -484,37 +525,69 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul2(const FuncGraphPtr &func_
return batch_matmul;
}
CNodePtr DynamicRnnGradFissionV2::CreateTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &dw_reduce_sum,
const RNNShapeSpecs &specs) const {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTranspose->name())),
dw_reduce_sum};
auto transpose = NewCNode(transpose_inputs, func_graph);
std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size, kDimMultiNum * specs.hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dw_reduce_sum, 0)}, {out_shape},
transpose.get());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{1, 0, 2, 3}), transpose);
AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), transpose);
AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), transpose);
AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), transpose);
AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRACTAL_ZN_RNN}), transpose);
return transpose;
}
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReduceSum(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const {
const AnfNodePtr &batch_matmul,
const RNNShapeSpecs &specs) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
batch_matmul};
auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
// Set infer data type and shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
kDimMultiNum * specs.hidden_nz_size * kCubeSize};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
reduce_sum.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
return reduce_sum;
auto ret_node = reduce_sum;
if (specs.shape_need_align) {
ret_node = CreateTranspose(func_graph, reduce_sum, specs);
}
return ret_node;
}
AnfNodePtr DynamicRnnGradFissionV2::CreateDwReshape(const FuncGraphPtr &func_graph,
const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const {
const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
batch_matmul};
auto reshape = NewCNode(reshape_inputs, func_graph);
// Set infer data type and shape
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
kDimMultiNum * specs.hidden_nz_size * kCubeSize};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
reshape.get());
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
return reshape;
auto ret_node = reshape;
if (specs.shape_need_align) {
ret_node = CreateTranspose(func_graph, reshape, specs);
}
return ret_node;
}
AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_graph,
@ -537,8 +610,8 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_gra
}
AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &,
const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &value_node) const {
const AnfNodePtr &lstm_input_grad, const AnfNodePtr &value_node,
const RNNShapeSpecs &specs) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Create node
auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
@ -546,12 +619,18 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_g
batch_matmul};
auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
// Set infer data type and shape
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
std::vector<size_t> out_shape = {kDimMultiNum * specs.hidden_size};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
if (specs.shape_need_align) {
AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_DEFAULT}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_ND_RNN_BIAS}), reduce_sum);
}
return reduce_sum;
}
@ -572,20 +651,28 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
return nullptr;
}
if (AnfAlgo::IsDynamicShape(node)) {
MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
MS_LOG(INFO) << "DynamicRNNGrad is dynamic shape, can not do fission.";
return nullptr;
}
std::vector<AnfNodePtr> new_outputs;
auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
size_t hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
if (hidden_size % kCubeSize != 0) {
MS_LOG(EXCEPTION) << "`hidden_size` in this node should be multiple of 16, but got " << hidden_size << ". "
<< dynamic_rnn_grad_cnode->DebugString();
auto input0_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex1), 0);
RNNShapeSpecs specs;
specs.t_size = input0_shape[0];
specs.batch_size = input0_shape[1];
specs.input_size = input0_shape[kDim2];
specs.hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
if (specs.hidden_size % kCubeSize != 0) {
specs.shape_need_align = true;
SetAttrInputAndHiddenSize(func_graph, dynamic_rnn_grad_cnode, SizeToLong(specs.input_size),
SizeToLong(specs.hidden_size));
}
specs.batch_nz_size = (specs.batch_size + kCubeSize - 1) / kCubeSize;
specs.input_nz_size = (specs.input_size + kCubeSize - 1) / kCubeSize;
specs.hidden_nz_size = (specs.hidden_size + kCubeSize - 1) / kCubeSize;
std::vector<AnfNodePtr> new_outputs;
auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, specs, &new_outputs);
AnfNodePtr concat = nullptr;
if (t_size != 1) {
if (specs.t_size != 1) {
auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
@ -595,17 +682,17 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
if (t_size != 1) {
auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
if (specs.t_size != 1) {
auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
make_tuple_inputs.emplace_back(dw_reduce_sum);
} else {
auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
make_tuple_inputs.emplace_back(dw_reshape);
}
auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
// create reduce_sum_2
auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node, specs);
make_tuple_inputs.emplace_back(db_reduce_sum);
make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
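
For orientation, the size arithmetic introduced here can be checked in isolation. The sketch below is standalone and not part of the patch; the concrete sizes are made up, and only the `(x + kCubeSize - 1) / kCubeSize` rounding and the `kDimMultiNum * hidden_nz_size * kCubeSize` gate width mirror the code above.

```cpp
// Standalone sketch of the RNNShapeSpecs arithmetic (illustrative values only).
#include <cstddef>
#include <iostream>

int main() {
  constexpr size_t kCubeSize = 16;     // Ascend cube unit, as in the pass
  constexpr size_t kDimMultiNum = 4;   // four LSTM gates
  size_t input_size = 96;              // assumed example value
  size_t hidden_size = 17;             // not a multiple of 16
  size_t input_nz_size = (input_size + kCubeSize - 1) / kCubeSize;    // 6
  size_t hidden_nz_size = (hidden_size + kCubeSize - 1) / kCubeSize;  // 2
  bool shape_need_align = (hidden_size % kCubeSize != 0);             // true
  // Width of the padded gate output used for BasicLSTMCellCStateGradV2.
  size_t gate_width = kDimMultiNum * hidden_nz_size * kCubeSize;      // 128
  std::cout << input_nz_size << " " << hidden_nz_size << " "
            << shape_need_align << " " << gate_width << "\n";
  return 0;
}
```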

View File

@ -22,6 +22,17 @@
namespace mindspore {
namespace opt {
struct RNNShapeSpecs {
size_t t_size;
size_t batch_size;
size_t input_size;
size_t hidden_size;
size_t batch_nz_size;
size_t input_nz_size;
size_t hidden_nz_size;
bool shape_need_align = false;
};
class DynamicRnnGradFissionV2 : public PatternProcessPass {
public:
explicit DynamicRnnGradFissionV2(bool multigraph = true)
@ -32,16 +43,16 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {
private:
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
const RNNShapeSpecs &specs, std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const std::vector<std::vector<size_t>> &split_shapes,
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
size_t num_split_x) const;
void CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const std::vector<std::vector<AnfNodePtr>> &result_nodes, size_t num_split_x,
std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
bool shape_need_align, std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<AnfNodePtr> *outputs) const;
const RNNShapeSpecs &specs, std::vector<AnfNodePtr> *outputs) const;
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &splitv) const;
@ -52,13 +63,15 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {
const AnfNodePtr &concat) const;
AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &node) const;
CNodePtr CreateTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &dw_reduce_sum,
const RNNShapeSpecs &specs) const;
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const;
const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const;
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
const AnfNodePtr &batch_matmul) const;
const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const;
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
const AnfNodePtr &value_node) const;
const AnfNodePtr &value_node, const RNNShapeSpecs &specs) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -739,7 +739,9 @@ std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
} else if (dim_last2 == input_size + hidden_size) {
device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
} else {
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
"`hidden_size` or `input_size + hidden_size`, but got second-last dim value: "
<< dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
}
device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
device_shape.push_back(NUM16);
@ -754,8 +756,8 @@ std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &sha
}
int64_t input_size = input_hidden_size[0];
int64_t hidden_size = input_hidden_size[1];
auto dim_last1 = shape[shape.size() - 1];
auto dim_last2 = shape[shape.size() - 2];
auto dim_last1 = shape[shape.size() - kDim1];
auto dim_last2 = shape[shape.size() - kDim2];
const int64_t NUM16 = 16;
const int64_t C0 = SizeToLong(kCubeSize);
@ -767,7 +769,9 @@ std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &sha
} else if (dim_last2 == input_size + hidden_size) {
device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
} else {
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
"`hidden_size` or `input_size + hidden_size` or `-1`, but got second-last dim value: "
<< dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
}
if (dim_last1 == Shape::SHP_ANY) {
device_shape[shape.size() - kDim1] = Shape::SHP_ANY;
@ -857,18 +861,25 @@ int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
if (!node->isa<CNode>()) {
if (!node->isa<CNode>() && !node->isa<Parameter>()) {
return input_hidden_size;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
MS_LOG(EXCEPTION)
<< "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
<< cnode->DebugString();
if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
input_hidden_size[0] = param->input_size();
input_hidden_size[1] = param->hidden_size();
} else {
CNodePtr cnode = node->cast<CNodePtr>();
if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
!AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
MS_LOG(EXCEPTION)
<< "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
<< node->DebugString();
}
input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
}
input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
return input_hidden_size;
}
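
The improved error messages describe how the second-last host dim is mapped onto FRACTAL_ZN_RNN. A standalone sketch of that mapping follows; only the `input_size + hidden_size` branch is visible in the hunk above, the other two branches are assumptions read off the message text, and `DivCeil`/`NUM16` simply mirror the names in the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// DivCeil and NUM16 mirror the names used in the patch.
int64_t DivCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Maps the second-last host dim of a FRACTAL_ZN_RNN tensor onto the device
// shape. Only the input_size + hidden_size branch appears in the hunk above;
// the other two branches are assumptions inferred from the error message.
int64_t MapSecondLastDim(int64_t dim_last2, int64_t input_size, int64_t hidden_size) {
  const int64_t NUM16 = 16;
  if (dim_last2 == input_size) return DivCeil(input_size, NUM16);
  if (dim_last2 == hidden_size) return DivCeil(hidden_size, NUM16);
  if (dim_last2 == input_size + hidden_size) {
    return DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  }
  throw std::invalid_argument("second-last dim must be input_size, hidden_size, or their sum");
}

int main() {
  // input_size = 96, hidden_size = 17: the concatenated dim 113 maps to 6 + 2 = 8.
  std::cout << MapSecondLastDim(96 + 17, 96, 17) << "\n";
  return 0;
}
```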

View File

@ -587,7 +587,25 @@ void FillNoneInKernelInfo(const CNodePtr &kernel_node, std::vector<kernel::Kerne
(*kernel_info_list)[idx] = builder->Build();
}
}
void ResetPreFixedFormat(const CNodePtr &kernel_node, kernel::KernelBuildInfoPtr *selected_kernel_info) {
if (!AnfAlgo::HasNodeAttr(kAttrFixedInputFormat, kernel_node) ||
!AnfAlgo::HasNodeAttr(kAttrFixedOutputFormat, kernel_node)) {
return;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(*selected_kernel_info);
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputsFormat(AnfAlgo::GetNodeAttr<std::vector<string>>(kernel_node, kAttrFixedInputFormat));
builder->SetOutputsFormat(AnfAlgo::GetNodeAttr<std::vector<string>>(kernel_node, kAttrFixedOutputFormat));
*selected_kernel_info = builder->Build();
MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
<< " selected kernel build info after reset fixed format: " << (*selected_kernel_info)->ToString();
AnfAlgo::EraseNodeAttr(kAttrFixedInputFormat, kernel_node);
AnfAlgo::EraseNodeAttr(kAttrFixedOutputFormat, kernel_node);
}
} // namespace
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
@ -613,14 +631,14 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
}
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
const std::vector<kernel::KernelBuildInfoPtr> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
KernelSelectStatus select_status = kNoMatched;
if (kernel_info_list.empty()) {
return select_status;
}
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
kernel::KernelBuildInfoPtr selected_kernel_info = nullptr;
// Matched kernel info
// Filter kernel info matched with me inferred type
auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
@ -642,6 +660,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
// Set kernel build info to node
MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
<< " selected: " << selected_kernel_info->ToString();
ResetPreFixedFormat(kernel_node, &selected_kernel_info);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
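
The new `kAttrFixedInputFormat`/`kAttrFixedOutputFormat` attrs act as a one-shot override on kernel selection: a pass pins the formats it needs, `ResetPreFixedFormat` copies them into the selected build info, and the attrs are erased so the override cannot fire twice. Below is a minimal standalone model of that contract using plain structs, not the MindSpore API; the format strings are illustrative.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-ins for the node and kernel build info; not the MindSpore classes.
struct Node {
  std::map<std::string, std::vector<std::string>> attrs;
};
struct BuildInfo {
  std::vector<std::string> input_formats;
  std::vector<std::string> output_formats;
};

// Mirrors the shape of ResetPreFixedFormat: if a pass pinned formats on the
// node, they overwrite whatever selection picked, then the attrs are erased.
void ResetPreFixedFormat(Node *node, BuildInfo *selected) {
  auto in = node->attrs.find("fixed_input_format");
  auto out = node->attrs.find("fixed_output_format");
  if (in == node->attrs.end() || out == node->attrs.end()) return;
  selected->input_formats = in->second;
  selected->output_formats = out->second;
  node->attrs.erase(in);
  node->attrs.erase(out);
}

int main() {
  Node split_v;
  split_v.attrs["fixed_input_format"] = {"FRACTAL_NZ"};
  split_v.attrs["fixed_output_format"] = {"FRACTAL_NZ", "FRACTAL_NZ"};
  BuildInfo selected{{"DefaultFormat"}, {"DefaultFormat", "DefaultFormat"}};
  ResetPreFixedFormat(&split_v, &selected);
  std::cout << selected.input_formats[0] << "\n";  // FRACTAL_NZ
  return 0;
}
```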

View File

@ -299,6 +299,7 @@ constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulOpName = "MatMul";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBatchMatMulOpName = "BatchMatMul";
constexpr auto kBatchMatMulV2OpName = "BatchMatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";
constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2";
@ -493,6 +494,8 @@ constexpr auto kAttrInputSize = "input_size";
constexpr auto kAttrDstType = "dst_type";
constexpr auto kAttrDump = "dump";
constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
constexpr auto kAttrFixedInputFormat = "fixed_input_format";
constexpr auto kAttrFixedOutputFormat = "fixed_output_format";
constexpr auto kAttrFuncType = "func_type";
constexpr auto kAttrCustAicpu = "cust_aicpu";

View File

@ -497,6 +497,7 @@ inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimBatchMatMulV2 = std::make_shared<Primitive>("BatchMatMulV2");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce");

View File

@ -831,14 +831,39 @@ class MS_CORE_API Parameter final : public ANode {
/// \brief Set groups attr in FracZ format.
///
/// \param[in] fracz_group Groups attr in FracZ format.
void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
void set_fracz_group(int64_t fracz_group) { format_attrs_.fracz_group = fracz_group; }
/// \brief Get groups attr in FracZ format.
///
/// \return Groups attr in FracZ format.
int64_t fracz_group() { return fracz_group_; }
int64_t fracz_group() { return format_attrs_.fracz_group; }
/// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \param[in] input_size input_size attr in FracNZ_RNN or ND_RNN_Bias format.
void set_input_size(int64_t input_size) { format_attrs_.input_size = input_size; }
/// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format.
int64_t input_size() { return format_attrs_.input_size; }
/// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \param[in] hidden_size hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
void set_hidden_size(int64_t hidden_size) { format_attrs_.hidden_size = hidden_size; }
/// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
int64_t hidden_size() { return format_attrs_.hidden_size; }
private:
struct FormatAttr {
int64_t fracz_group = 1;
int64_t input_size = 0;
int64_t hidden_size = 0;
};
std::string name_;
bool has_default_;
std::set<uint32_t> not_used_in_graphs_;
@ -846,8 +871,8 @@ class MS_CORE_API Parameter final : public ANode {
ValuePtr default_param_;
// The count of graphs using the parameter.
int used_graph_count_;
// groups attr in FracZ format
int64_t fracz_group_ = 1;
// some attrs used in special format
FormatAttr format_attrs_;
bool is_top_graph_param_ = false;
};
using ParameterPtr = std::shared_ptr<Parameter>;
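
With the `FormatAttr` bundle above, a weight `Parameter` can carry `input_size`/`hidden_size` directly, which is what lets `GetAttrInputAndHiddenSize` fall back to the parameter when a FRACTAL_ZN_RNN or ND_RNN_BIAS tensor is not produced by an annotated CNode. A standalone model of that flow (not the MindSpore classes, illustrative values only):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for the Parameter accessors added above.
class Parameter {
 public:
  void set_input_size(int64_t s) { input_size_ = s; }
  void set_hidden_size(int64_t s) { hidden_size_ = s; }
  int64_t input_size() const { return input_size_; }
  int64_t hidden_size() const { return hidden_size_; }

 private:
  int64_t input_size_ = 0;   // defaults mirror the FormatAttr struct above
  int64_t hidden_size_ = 0;
};

// Modeled fallback: read the RNN sizes straight from the parameter.
std::vector<int64_t> GetAttrInputAndHiddenSize(const Parameter &param) {
  return {param.input_size(), param.hidden_size()};
}

int main() {
  Parameter rnn_weight;
  rnn_weight.set_input_size(96);   // illustrative values
  rnn_weight.set_hidden_size(17);
  auto sizes = GetAttrInputAndHiddenSize(rnn_weight);
  std::cout << sizes[0] << " " << sizes[1] << "\n";  // 96 17
  return 0;
}
```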

View File

@ -55,6 +55,7 @@ from .assign_add_ds import _assign_add_ds_tbe
from .assign_sub import _assign_sub_tbe
from .batch_matmul import _batch_matmul_tbe
from .batch_matmul_ds import _batch_matmul_ds_tbe
from .batch_matmul_v2 import _batch_matmul_v2_tbe
from .batchnorm import _batch_norm_tbe
from .batchnorm_grad import _batch_norm_grad_tbe
from .bias_add import _bias_add_tbe

View File

@ -0,0 +1,48 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
batch_matmul_v2_op_info = TBERegOp("BatchMatMulV2") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batch_matmul_v2.so") \
.compute_cost(10) \
.kernel_name("batch_matmul_v2") \
.attr("transpose_x1", "required", "bool", "all") \
.attr("transpose_x2", "required", "bool", "all") \
.attr("offset_x", "optional", "int", "all", "0") \
.partial_flag(True) \
.need_check_supported(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.is_dynamic_format(True) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None,
DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
DataType.F32_None) \
.get_op_info()
@op_info_register(batch_matmul_v2_op_info)
def _batch_matmul_v2_tbe():
"""BatchMatMulV2 TBE register"""
return