!26370 DynamicRNNGrad support `hidden_size not multiple of 16` scene
Merge pull request !26370 from yuchaojie/ir_fusion4
commit fa5ea7b3a6
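This change teaches the DynamicRNNGrad fission pass to handle an RNN whose hidden_size (and input_size) is not a multiple of 16: instead of rejecting such graphs, the pass records cube-aligned "NZ" block counts, pads the split/concat shapes to those sizes, and pins the special device formats (FRAC_NZ, FRACTAL_ZN_RNN, ND_RNN_BIAS) on the nodes it creates. A minimal sketch of the alignment arithmetic used throughout the diff, assuming kCubeSize == 16 (the helper name is illustrative, not part of the patch):

// Number of 16-wide cube blocks needed to cover a dimension (ceiling division).
size_t NzBlocks(size_t dim, size_t cube_size = 16) {
  return (dim + cube_size - 1) / cube_size;
}
// e.g. hidden_size = 100 -> NzBlocks(100) = 7, padded length 7 * 16 = 112;
// shape_need_align becomes true because 100 % 16 != 0.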
@@ -125,7 +125,7 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
 public:
 KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }

-explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
+explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
 : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
 SetKernelType(kernel_build_info->kernel_type());
 SetFusionType(kernel_build_info->fusion_type());
@@ -192,6 +192,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
 return make_tuple;
 }
 } // namespace

 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
 AnfNodePtr trans_node = nullptr;
@@ -244,9 +245,9 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
 }

 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
-MS_EXCEPTION_IF_NULL(trans_data);
-auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
+const AnfNodePtr &trans_node, const std::string &reshape_type, const TypeId &type_id) {
+MS_EXCEPTION_IF_NULL(trans_node);
+auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
 MS_EXCEPTION_IF_NULL(ori_build_info);
 auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
 MS_EXCEPTION_IF_NULL(builder);
@@ -258,8 +259,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
 builder->SetOutputsDeviceType({type_id});
 builder->SetInputsDeviceType({type_id});
 }
-AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
-SetTransNodeAttr(trans_data->cast<CNodePtr>());
+AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_node.get());
+SetTransNodeAttr(trans_node->cast<CNodePtr>());
 }

 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const AnfNodePtr &orig_node,
@@ -100,7 +100,7 @@ class OpFinder {
 using OpFinderPtr = std::shared_ptr<OpFinder>;

 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::string &reshape_type = {""},
+const AnfNodePtr &trans_node, const std::string &reshape_type = {""},
 const TypeId &type_id = kTypeUnknown);

 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const AnfNodePtr &orig_node,
@@ -96,7 +96,7 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
 for (size_t i = 0; i < in_num; ++i) {
 if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
 MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
-<< cnode->DebugString() << "]";
+<< cnode->DebugString() << "], fullname: " << node->fullname_with_scope();
 }
 }
 return nullptr;
@@ -40,13 +40,15 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
 auto input_format = AnfAlgo::GetInputFormat(cnode, index);
 auto input_node = AnfAlgo::GetInputNode(cnode, index);
 // convert the format of node's input node to default
-if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) {
+if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() &&
+(prev_node_out_infer_shape.size() > 1 || prev_input_format == kOpFormat_ND_RNN_BIAS)) {
 auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, 0, false);
 AnfAlgo::SetNodeInput(cnode, trans_node, index);
 has_changed = true;
 }
 // convert node's output format
-if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() && prev_node_out_infer_shape.size() > 1) {
+if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() &&
+(prev_node_out_infer_shape.size() > 1 || input_format == kOpFormat_ND_RNN_BIAS)) {
 auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, index, true);
 AnfAlgo::SetNodeInput(cnode, trans_node, index);
 has_changed = true;
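The RunOpInsertTransData hunk above widens the trans-data condition: previously a TransData was only inserted for a non-common format on a tensor of rank greater than 1, which skipped 1-D bias tensors; now a tensor in kOpFormat_ND_RNN_BIAS also gets one. A hypothetical helper expressing the new predicate (the name and factoring are mine; the logic is taken from the hunk):

bool NeedsTransData(const std::string &format, const std::vector<size_t> &infer_shape) {
  // A non-default device format must go through TransData when the tensor is
  // multi-dimensional, or when it is the padded 1-D RNN bias format.
  return kCommonFormatSet.find(format) == kCommonFormatSet.end() &&
         (infer_shape.size() > 1 || format == kOpFormat_ND_RNN_BIAS);
}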
@@ -15,9 +15,11 @@
 */
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include <vector>
+#include <string>
 #include <memory>
 #include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
 #include "utils/trace_base.h"
 #include "utils/tensor_construct_utils.h"

@@ -34,9 +36,40 @@ constexpr int64_t kAttrAxis2Value = 2;
 constexpr int64_t kAttrNumSplitValue = 2;
 constexpr int64_t kAttrSplitDimValue = 2;
 constexpr size_t kDimMultiNum = 4;

+void SetAttrInputAndHiddenSize(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+int64_t input_size, int64_t hidden_size) {
+auto input = dynamic_rnn_grad_cnode->input(kIndex2);
+MS_EXCEPTION_IF_NULL(input);
+// set for input
+while (input->isa<CNode>()) {
+AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), input);
+AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), input);
+auto input_cnode = input->cast<CNodePtr>();
+input = input_cnode->input(kIndex1);
+}
+if (input->isa<Parameter>()) {
+auto param = input->cast<ParameterPtr>();
+param->set_input_size(input_size);
+param->set_hidden_size(hidden_size);
+}
+
+// set for output
+auto manager = func_graph->manager();
+MS_EXCEPTION_IF_NULL(manager);
+for (auto getitem_index : manager->node_users()[dynamic_rnn_grad_cnode]) {
+if (AnfAlgo::CheckPrimitiveType(getitem_index.first, prim::kPrimTupleGetItem)) {
+for (auto node_index : manager->node_users()[getitem_index.first]) {
+AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(input_size), node_index.first);
+AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(hidden_size), node_index.first);
+}
+}
+}
+}
 } // namespace

 void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
+const RNNShapeSpecs &specs,
 std::vector<std::vector<AnfNodePtr>> *result_nodes) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
@@ -45,19 +78,15 @@ void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, co
 std::vector<AnfNodePtr> matmul_nodes;
 std::vector<AnfNodePtr> split_nodes;
 // Get the size of t
-auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
-size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
 auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);

-for (size_t i = 0; i < t_size; ++i) {
+for (size_t i = 0; i < specs.t_size; ++i) {
 // Create basic_lstm_cell_c_state_grad
 std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
 NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
 auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);

-std::vector<size_t> output0_dims{
-origin_input9_shape[kDim0],
-kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
+std::vector<size_t> output0_dims{specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize};
 std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
 basic_lstm_cell_c_state_grad.get());
@@ -66,30 +95,40 @@ void DynamicRnnGradFissionV2::CreateTLoopNode(const FuncGraphPtr &func_graph, co

 // Create matmul
 auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
-std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+std::vector<AnfNodePtr> matmul_inputs;
+if (specs.shape_need_align) {
+matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
+} else {
+matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+}
 auto matmul = NewCNode(matmul_inputs, func_graph);
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
 matmul.get());
 AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
 AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);
+if (specs.shape_need_align) {
+AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), matmul);
+AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), matmul);
+}

 // Create split
 std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
 auto split_v = NewCNode(splitv_input, func_graph);
-auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
-auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
-std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
-std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
+std::vector<size_t> split_v_output0_shape{IntToSize(1), specs.batch_size, specs.input_size};
+std::vector<size_t> split_v_output1_shape{IntToSize(1), specs.batch_size, specs.hidden_size};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
 {split_v_output0_shape, split_v_output1_shape}, split_v.get());

 AnfAlgo::SetNodeAttr(kAttrSizeSplits,
-MakeValue(std::vector<int64_t>{
-SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
-SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
+MakeValue(std::vector<int64_t>{SizeToLong(specs.input_nz_size * kCubeSize),
+SizeToLong(specs.hidden_nz_size * kCubeSize)}),
 split_v);
 AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
 AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);
+if (specs.shape_need_align) {
+AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), split_v);
+AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ, kOpFormat_FRAC_NZ}),
+split_v);
+}

 basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
 matmul_nodes.emplace_back(matmul);
@@ -117,7 +156,7 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateLSTMSPlitV(const FuncGraphPtr &func_gr
 void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph,
 const CNodePtr &dynamic_rnn_grad_cnode,
 const std::vector<std::vector<AnfNodePtr>> &result_nodes,
-size_t num_split_x,
+size_t num_split_x, bool shape_need_align,
 std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const {
 auto &basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
 auto &matmul_nodes = result_nodes[kIndex1];
@@ -166,7 +205,6 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
 (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
 (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
 auto basic_lstm_cell_c_state_grad = NewCNode(basic_lstm_cell_c_state_grad_inputs, func_graph);
-MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
 basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
 AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
 // Create outputs for current basic_lstm_cell_c_state_grad node
@@ -176,11 +214,15 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
 pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;

 // Create MatMul
-std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
+std::vector<AnfNodePtr> matmul_inputs;
+if (shape_need_align) {
+matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMulV2->name())));
+} else {
+matmul_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+}
 (void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
 (void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
 auto matmul = NewCNode(matmul_inputs, func_graph);
-MS_EXCEPTION_IF_NULL(matmul);
 matmul->set_abstract(matmul_nodes[i]->abstract());
 AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);

@@ -188,7 +230,6 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g
 std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
 matmul};
 auto split_v = NewCNode(splitv_input, func_graph);
-MS_EXCEPTION_IF_NULL(split_v);
 split_v->set_abstract(split_nodes[i]->abstract());
 AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);

@@ -223,9 +264,10 @@ void DynamicRnnGradFissionV2::CreateTLoopNodeWithEdge(const FuncGraphPtr &func_g

 AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &func_graph,
 const CNodePtr &dynamic_rnn_grad_cnode,
+const RNNShapeSpecs &specs,
 std::vector<AnfNodePtr> *outputs) const {
 std::vector<std::vector<AnfNodePtr>> result_nodes;
-CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
+CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, specs, &result_nodes);

 auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
 std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
@@ -290,7 +332,8 @@ AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &fun

 // Add edges
 std::vector<std::vector<AnfNodePtr>> loop_node_outputs;
-CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, &loop_node_outputs);
+CreateTLoopNodeWithEdge(func_graph, dynamic_rnn_grad_cnode, result_nodes, num_split_x, specs.shape_need_align,
+&loop_node_outputs);
 auto &pre_basic_lstm_cell_c_state_grad_outputs = loop_node_outputs[kIndex0];
 auto &pre_split_outputs = loop_node_outputs[kIndex1];
 auto &lstm_x_concat_input = loop_node_outputs[kIndex2];
@@ -306,10 +349,8 @@ AnfNodePtr DynamicRnnGradFissionV2::AddLSTMInputGradNode(const FuncGraphPtr &fun

 // Create lstm_gage_concat
 auto lstm_gage_concat = NewCNode(lstm_gage_concat_input, func_graph);
-auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
 AnfAlgo::SetOutputInferTypeAndShape(
-{kNumberTypeFloat16},
-{{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
+{kNumberTypeFloat16}, {{specs.t_size, specs.batch_size, kDimMultiNum * specs.hidden_nz_size * kCubeSize}},
 lstm_gage_concat.get());
 AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
 AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
@@ -484,37 +525,69 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateBatchMatMul2(const FuncGraphPtr &func_
 return batch_matmul;
 }

+CNodePtr DynamicRnnGradFissionV2::CreateTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &dw_reduce_sum,
+const RNNShapeSpecs &specs) const {
+MS_EXCEPTION_IF_NULL(func_graph);
+std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTranspose->name())),
+dw_reduce_sum};
+auto transpose = NewCNode(transpose_inputs, func_graph);
+std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size, kDimMultiNum * specs.hidden_size};
+AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dw_reduce_sum, 0)}, {out_shape},
+transpose.get());
+AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{1, 0, 2, 3}), transpose);
+AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), transpose);
+AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), transpose);
+AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_FRAC_NZ}), transpose);
+AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_FRACTAL_ZN_RNN}), transpose);
+return transpose;
+}
+
 AnfNodePtr DynamicRnnGradFissionV2::CreateDwReduceSum(const FuncGraphPtr &func_graph,
 const CNodePtr &dynamic_rnn_grad_cnode,
-const AnfNodePtr &batch_matmul) const {
+const AnfNodePtr &batch_matmul,
+const RNNShapeSpecs &specs) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 // Create node
 std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
 batch_matmul};
 auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
 // Set infer data type and shape
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
-{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
+std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
+kDimMultiNum * specs.hidden_nz_size * kCubeSize};
+AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
+reduce_sum.get());
 // Set attr
 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
 AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
-return reduce_sum;
+auto ret_node = reduce_sum;
+if (specs.shape_need_align) {
+ret_node = CreateTranspose(func_graph, reduce_sum, specs);
+}
+return ret_node;
 }

 AnfNodePtr DynamicRnnGradFissionV2::CreateDwReshape(const FuncGraphPtr &func_graph,
 const CNodePtr &dynamic_rnn_grad_cnode,
-const AnfNodePtr &batch_matmul) const {
+const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 // Create node
 std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
 batch_matmul};
 auto reshape = NewCNode(reshape_inputs, func_graph);
 // Set infer data type and shape
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
-{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
+std::vector<size_t> out_shape = {specs.input_size + specs.hidden_size,
+kDimMultiNum * specs.hidden_nz_size * kCubeSize};
+AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)}, {out_shape},
+reshape.get());
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
-return reshape;
+auto ret_node = reshape;
+if (specs.shape_need_align) {
+ret_node = CreateTranspose(func_graph, reshape, specs);
+}
+return ret_node;
 }

 AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_graph,
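When shape_need_align is set, the dw ReduceSum/Reshape above keeps the NZ-padded gate dimension and the new CreateTranspose node converts the result from FRAC_NZ back to FRACTAL_ZN_RNN with the unpadded infer shape. An illustrative calculation with input_size = 96 and hidden_size = 100 (values chosen only for this example):

// hidden_nz_size = ceil(100 / 16) = 7
// dw reduce_sum / reshape infer shape : (96 + 100) x (4 * 7 * 16) = 196 x 448  (padded)
// dw transpose infer shape            : (96 + 100) x (4 * 100)    = 196 x 400  (FRACTAL_ZN_RNN output)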
@@ -537,8 +610,8 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateValueNode(const FuncGraphPtr &func_gra
 }

 AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &,
-const AnfNodePtr &lstm_input_grad,
-const AnfNodePtr &value_node) const {
+const AnfNodePtr &lstm_input_grad, const AnfNodePtr &value_node,
+const RNNShapeSpecs &specs) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 // Create node
 auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
@@ -546,12 +619,18 @@ AnfNodePtr DynamicRnnGradFissionV2::CreateDbReduceSum(const FuncGraphPtr &func_g
 batch_matmul};
 auto reduce_sum = NewCNode(reduce_sum_inputs, func_graph);
 // Set infer data type and shape
-auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
+std::vector<size_t> out_shape = {kDimMultiNum * specs.hidden_size};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
 // Set attr
 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
 AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
+if (specs.shape_need_align) {
+AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(SizeToLong(specs.input_size)), reduce_sum);
+AnfAlgo::SetNodeAttr(kAttrHiddenSize, MakeValue(SizeToLong(specs.hidden_size)), reduce_sum);
+AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_DEFAULT}), reduce_sum);
+AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_ND_RNN_BIAS}), reduce_sum);
+}
 return reduce_sum;
 }

@@ -572,20 +651,28 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
 return nullptr;
 }
 if (AnfAlgo::IsDynamicShape(node)) {
-MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
+MS_LOG(INFO) << "DynamicRNNGrad is dynamic shape, can not do fission.";
 return nullptr;
 }
-std::vector<AnfNodePtr> new_outputs;
-auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
-size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
-size_t hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
-if (hidden_size % kCubeSize != 0) {
-MS_LOG(EXCEPTION) << "`hidden_size` in this node should be multiple of 16, but got " << hidden_size << ". "
-<< dynamic_rnn_grad_cnode->DebugString();
+auto input0_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex1), 0);
+RNNShapeSpecs specs;
+specs.t_size = input0_shape[0];
+specs.batch_size = input0_shape[1];
+specs.input_size = input0_shape[kDim2];
+specs.hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
+if (specs.hidden_size % kCubeSize != 0) {
+specs.shape_need_align = true;
+SetAttrInputAndHiddenSize(func_graph, dynamic_rnn_grad_cnode, SizeToLong(specs.input_size),
+SizeToLong(specs.hidden_size));
 }
+specs.batch_nz_size = (specs.batch_size + kCubeSize - 1) / kCubeSize;
+specs.input_nz_size = (specs.input_size + kCubeSize - 1) / kCubeSize;
+specs.hidden_nz_size = (specs.hidden_size + kCubeSize - 1) / kCubeSize;
+
+std::vector<AnfNodePtr> new_outputs;
+auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, specs, &new_outputs);
 AnfNodePtr concat = nullptr;
-if (t_size != 1) {
+if (specs.t_size != 1) {
 auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
 auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
 concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
@@ -595,17 +682,17 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph

 auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
 std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
-if (t_size != 1) {
-auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+if (specs.t_size != 1) {
+auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
 make_tuple_inputs.emplace_back(dw_reduce_sum);
 } else {
-auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
+auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul, specs);
 make_tuple_inputs.emplace_back(dw_reshape);
 }

 auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
 // create reduce_sum_2
-auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
+auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node, specs);
 make_tuple_inputs.emplace_back(db_reduce_sum);
 make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
 auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
@@ -22,6 +22,17 @@

 namespace mindspore {
 namespace opt {
+struct RNNShapeSpecs {
+size_t t_size;
+size_t batch_size;
+size_t input_size;
+size_t hidden_size;
+size_t batch_nz_size;
+size_t input_nz_size;
+size_t hidden_nz_size;
+bool shape_need_align = false;
+};
+
 class DynamicRnnGradFissionV2 : public PatternProcessPass {
 public:
 explicit DynamicRnnGradFissionV2(bool multigraph = true)
@@ -32,16 +43,16 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {

 private:
 void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
+const RNNShapeSpecs &specs, std::vector<std::vector<AnfNodePtr>> *result_nodes) const;
 AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
 const std::vector<std::vector<size_t>> &split_shapes,
 const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
 size_t num_split_x) const;
 void CreateTLoopNodeWithEdge(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
 const std::vector<std::vector<AnfNodePtr>> &result_nodes, size_t num_split_x,
-std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
+bool shape_need_align, std::vector<std::vector<AnfNodePtr>> *loop_node_outputs) const;
 AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-std::vector<AnfNodePtr> *outputs) const;
+const RNNShapeSpecs &specs, std::vector<AnfNodePtr> *outputs) const;
 AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
 AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
 const AnfNodePtr &splitv) const;
@@ -52,13 +63,15 @@ class DynamicRnnGradFissionV2 : public PatternProcessPass {
 const AnfNodePtr &concat) const;
 AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
 const AnfNodePtr &node) const;
+CNodePtr CreateTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &dw_reduce_sum,
+const RNNShapeSpecs &specs) const;
 AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-const AnfNodePtr &batch_matmul) const;
+const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const;
 AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-const AnfNodePtr &batch_matmul) const;
+const AnfNodePtr &batch_matmul, const RNNShapeSpecs &specs) const;
 AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) const;
 AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
-const AnfNodePtr &value_node) const;
+const AnfNodePtr &value_node, const RNNShapeSpecs &specs) const;
 };
 } // namespace opt
 } // namespace mindspore
@@ -739,7 +739,9 @@ std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
 } else if (dim_last2 == input_size + hidden_size) {
 device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
 } else {
-MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
+"`hidden_size` or `input_size + hidden_size`, but got second-last dim value: "
+<< dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
 }
 device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
 device_shape.push_back(NUM16);
@@ -754,8 +756,8 @@ std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &sha
 }
 int64_t input_size = input_hidden_size[0];
 int64_t hidden_size = input_hidden_size[1];
-auto dim_last1 = shape[shape.size() - 1];
-auto dim_last2 = shape[shape.size() - 2];
+auto dim_last1 = shape[shape.size() - kDim1];
+auto dim_last2 = shape[shape.size() - kDim2];
 const int64_t NUM16 = 16;
 const int64_t C0 = SizeToLong(kCubeSize);

@@ -767,7 +769,9 @@ std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &sha
 } else if (dim_last2 == input_size + hidden_size) {
 device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
 } else {
-MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
+"`hidden_size` or `input_size + hidden_size` or `-1`, but got second-last dim value: "
+<< dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
 }
 if (dim_last1 == Shape::SHP_ANY) {
 device_shape[shape.size() - kDim1] = Shape::SHP_ANY;
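FracZNRNNDeviceShape (and its dynamic-shape variant) maps a FRACTAL_ZN_RNN weight to cube-aligned device dimensions; the expanded log message now reports which host dimension failed the check. A worked example for the part of the computation visible in these hunks, assuming NUM16 == 16, C0 == kCubeSize == 16, and taking n_num = 4 for the four LSTM gates (an assumption, not shown in the hunk):

// host shape (input_size + hidden_size, 4 * hidden_size), input_size = 96, hidden_size = 100
// second-last device dim = DivCeil(96, 16) + DivCeil(100, 16) = 6 + 7 = 13
// last device dim        = 4 * DivCeil(100, 16)               = 28
// a cube dimension of 16 is then appended (device_shape.push_back(NUM16)).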
@@ -857,18 +861,25 @@ int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
 std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
 MS_EXCEPTION_IF_NULL(node);
 std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
-if (!node->isa<CNode>()) {
+if (!node->isa<CNode>() && !node->isa<Parameter>()) {
 return input_hidden_size;
 }
-auto cnode = node->cast<CNodePtr>();
-MS_EXCEPTION_IF_NULL(cnode);
-if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
-MS_LOG(EXCEPTION)
-<< "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
-<< cnode->DebugString();
+if (node->isa<Parameter>()) {
+auto param = node->cast<ParameterPtr>();
+input_hidden_size[0] = param->input_size();
+input_hidden_size[1] = param->hidden_size();
+} else {
+CNodePtr cnode = node->cast<CNodePtr>();
+if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
+!AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
+MS_LOG(EXCEPTION)
+<< "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
+<< node->DebugString();
+}
+input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
+input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
 }
-input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
-input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
 return input_hidden_size;
 }

@@ -587,7 +587,25 @@ void FillNoneInKernelInfo(const CNodePtr &kernel_node, std::vector<kernel::Kerne
 (*kernel_info_list)[idx] = builder->Build();
 }
 }

+void ResetPreFixedFormat(const CNodePtr &kernel_node, kernel::KernelBuildInfoPtr *selected_kernel_info) {
+if (!AnfAlgo::HasNodeAttr(kAttrFixedInputFormat, kernel_node) ||
+!AnfAlgo::HasNodeAttr(kAttrFixedOutputFormat, kernel_node)) {
+return;
+}
+
+auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(*selected_kernel_info);
+MS_EXCEPTION_IF_NULL(builder);
+builder->SetInputsFormat(AnfAlgo::GetNodeAttr<std::vector<string>>(kernel_node, kAttrFixedInputFormat));
+builder->SetOutputsFormat(AnfAlgo::GetNodeAttr<std::vector<string>>(kernel_node, kAttrFixedOutputFormat));
+*selected_kernel_info = builder->Build();
+MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
+<< " selected kernel build info after reset fixed format: " << (*selected_kernel_info)->ToString();
+AnfAlgo::EraseNodeAttr(kAttrFixedInputFormat, kernel_node);
+AnfAlgo::EraseNodeAttr(kAttrFixedOutputFormat, kernel_node);
+}
 } // namespace

 void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
 MS_EXCEPTION_IF_NULL(kernel_node);
 auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
@@ -613,14 +631,14 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
 }

 KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
-const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
+const std::vector<kernel::KernelBuildInfoPtr> &kernel_info_list) {
 MS_EXCEPTION_IF_NULL(kernel_node);
 KernelSelectStatus select_status = kNoMatched;
 if (kernel_info_list.empty()) {
 return select_status;
 }
 bool precision_reduce = false;
-std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
+kernel::KernelBuildInfoPtr selected_kernel_info = nullptr;
 // Matched kernel info
 // Filter kernel info matched with me inferred type
 auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
@@ -642,6 +660,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
 // Set kernel build info to node
 MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
 << " selected: " << selected_kernel_info->ToString();
+ResetPreFixedFormat(kernel_node, &selected_kernel_info);
 AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
 // Set format and data type for input tensor.
 if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
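The kAttrFixedInputFormat / kAttrFixedOutputFormat attributes give an IR pass a way to pin device formats on a node it creates; after kernel selection, ResetPreFixedFormat overwrites the chosen kernel build info with those formats and erases both attributes. A usage sketch that mirrors the reduce_sum hunk earlier in this diff:

// In a fission pass, pin the formats on a freshly created node:
AnfAlgo::SetNodeAttr(kAttrFixedInputFormat, MakeValue(std::vector<string>{kOpFormat_DEFAULT}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrFixedOutputFormat, MakeValue(std::vector<string>{kOpFormat_ND_RNN_BIAS}), reduce_sum);
// During kernel selection, SetMatchedKernelInfo calls ResetPreFixedFormat, which
// rebuilds the selected KernelBuildInfo with these formats and removes the attrs.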
@@ -299,6 +299,7 @@ constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
 constexpr auto kMatMulOpName = "MatMul";
 constexpr auto kMatMulV2OpName = "MatMulV2";
 constexpr auto kBatchMatMulOpName = "BatchMatMul";
+constexpr auto kBatchMatMulV2OpName = "BatchMatMulV2";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
 constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
 constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2";
@@ -493,6 +494,8 @@ constexpr auto kAttrInputSize = "input_size";
 constexpr auto kAttrDstType = "dst_type";
 constexpr auto kAttrDump = "dump";
 constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
+constexpr auto kAttrFixedInputFormat = "fixed_input_format";
+constexpr auto kAttrFixedOutputFormat = "fixed_output_format";
 constexpr auto kAttrFuncType = "func_type";
 constexpr auto kAttrCustAicpu = "cust_aicpu";

@@ -497,6 +497,7 @@ inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);
 inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
 inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
 inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
+inline const PrimitivePtr kPrimBatchMatMulV2 = std::make_shared<Primitive>("BatchMatMulV2");
 inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
 inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
 inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce");
@@ -831,14 +831,39 @@ class MS_CORE_API Parameter final : public ANode {
 /// \brief Set groups attr in FracZ format.
 ///
 /// \param[in] fracz_group Groups attr in FracZ format.
-void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
+void set_fracz_group(int64_t fracz_group) { format_attrs_.fracz_group = fracz_group; }

 /// \brief Get groups attr in FracZ format.
 ///
 /// \return Groups attr in FracZ format.
-int64_t fracz_group() { return fracz_group_; }
+int64_t fracz_group() { return format_attrs_.fracz_group; }

+/// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format.
+///
+/// \param[in] input_size input_size attr in FracNZ_RNN or ND_RNN_Bias format.
+void set_input_size(int64_t input_size) { format_attrs_.input_size = input_size; }
+
+/// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format.
+///
+/// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format.
+int64_t input_size() { return format_attrs_.input_size; }
+
+/// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
+///
+/// \param[in] hidden_size hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
+void set_hidden_size(int64_t hidden_size) { format_attrs_.hidden_size = hidden_size; }
+
+/// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
+///
+/// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
+int64_t hidden_size() { return format_attrs_.hidden_size; }
+
 private:
+struct FormatAttr {
+int64_t fracz_group = 1;
+int64_t input_size = 0;
+int64_t hidden_size = 0;
+};
 std::string name_;
 bool has_default_;
 std::set<uint32_t> not_used_in_graphs_;
@@ -846,8 +871,8 @@ class MS_CORE_API Parameter final : public ANode {
 ValuePtr default_param_;
 // The count of graphs using the parameter.
 int used_graph_count_;
-// groups attr in FracZ format
-int64_t fracz_group_ = 1;
+// some attrs used in special format
+FormatAttr format_attrs_;
 bool is_top_graph_param_ = false;
 };
 using ParameterPtr = std::shared_ptr<Parameter>;
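Storing input_size and hidden_size on the Parameter itself lets GetAttrInputAndHiddenSize recover the real RNN sizes for a weight or bias Parameter that carries no node attrs. A small sketch of the intended flow (the concrete sizes are examples; the accessors are the ones added above):

// SetAttrInputAndHiddenSize walks from DynamicRNNGrad's weight input up to the
// defining Parameter and records the real (unpadded) sizes on it:
param->set_input_size(96);
param->set_hidden_size(100);
// Later, when the weight is laid out as FRACTAL_ZN_RNN / ND_RNN_BIAS,
// GetAttrInputAndHiddenSize(param) returns {96, 100} instead of the {kAlign16, kAlign16} fallback.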
@@ -55,6 +55,7 @@ from .assign_add_ds import _assign_add_ds_tbe
 from .assign_sub import _assign_sub_tbe
 from .batch_matmul import _batch_matmul_tbe
 from .batch_matmul_ds import _batch_matmul_ds_tbe
+from .batch_matmul_v2 import _batch_matmul_v2_tbe
 from .batchnorm import _batch_norm_tbe
 from .batchnorm_grad import _batch_norm_grad_tbe
 from .bias_add import _bias_add_tbe
@@ -0,0 +1,48 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BatchMatMul op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+batch_matmul_v2_op_info = TBERegOp("BatchMatMulV2") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("batch_matmul_v2.so") \
+    .compute_cost(10) \
+    .kernel_name("batch_matmul_v2") \
+    .attr("transpose_x1", "required", "bool", "all") \
+    .attr("transpose_x2", "required", "bool", "all") \
+    .attr("offset_x", "optional", "int", "all", "0") \
+    .partial_flag(True) \
+    .need_check_supported(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .input(2, "bias", False, "optional", "all") \
+    .input(3, "offset_w", False, "optional", "all") \
+    .output(0, "y", False, "required", "all") \
+    .is_dynamic_format(True) \
+    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None,
+                  DataType.I32_None) \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
+                  DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
+                  DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(batch_matmul_v2_op_info)
+def _batch_matmul_v2_tbe():
+    """BatchMatMulV2 TBE register"""
+    return