forked from mindspore-Ecosystem/mindspore
Adapt DynamicGRUV2Grad for Ascend new backend.
parent e2e532dec3
commit 2aaf5e2e1b
@@ -20,6 +20,7 @@
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
+#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@@ -280,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
@@ -0,0 +1,344 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h"
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {

namespace {
constexpr size_t kDynamicGRUV2GradInputNum = 12;
constexpr size_t kDynamicGRUV2GradOutputNum = 6;
constexpr size_t kSplitVOutputNum = 2;
constexpr size_t kGRUV2HiddenGradOutputNum = 3;

AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &dynamic_gru_v2_grad_inputs = cnode->inputs();
  std::vector<AnfNodePtr> gru_v2_hidden_grad_inputs = {
    NewValueNode(std::make_shared<Primitive>(kGRUV2HiddenGradOpName)),
    dynamic_gru_v2_grad_inputs[3],
    dynamic_gru_v2_grad_inputs[5],
    dynamic_gru_v2_grad_inputs[6],
    dynamic_gru_v2_grad_inputs[7],
    dynamic_gru_v2_grad_inputs[8],
    dynamic_gru_v2_grad_inputs[9],
    dynamic_gru_v2_grad_inputs[10],
    dynamic_gru_v2_grad_inputs[11],
    dynamic_gru_v2_grad_inputs[12]};

  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node, kDynamicGRUV2GradOutputNum, &ori_outputs);
  auto gru_v2_hidden_grad_op = graph->NewCNode(gru_v2_hidden_grad_inputs);
  MS_EXCEPTION_IF_NULL(gru_v2_hidden_grad_op);
  auto h_dtype = AnfAlgo::GetOutputInferDataType(dynamic_gru_v2_grad_inputs[6], 0);
  auto types = {h_dtype, h_dtype, h_dtype};
  std::vector<size_t> dh_preh_shape = AnfAlgo::GetOutputInferShape(ori_outputs[5], 0);
  std::vector<size_t> dgate_h_shape = {AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[0],
                                       AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[1],
                                       3 * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[2]};
  std::vector<size_t> dnx_t_shape = AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0);
  auto shapes = {dh_preh_shape, dgate_h_shape, dnx_t_shape};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, gru_v2_hidden_grad_op.get());
  auto gate_order = AnfAlgo::GetNodeAttr<std::string>(cnode, "gate_order");
  AnfAlgo::SetNodeAttr("gate_order", MakeValue(gate_order), gru_v2_hidden_grad_op);
  return gru_v2_hidden_grad_op;
}

AnfNodePtr CreateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // SplitV
  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
  auto split_vd = graph->NewCNode(splitvd_input);
  MS_EXCEPTION_IF_NULL(split_vd);
  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
  size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[0];
  size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[1];
  size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[2];
  std::vector<size_t> shape = {t_size - IntToSize(1), batch, hidden_size};
  std::vector<size_t> shape2 = {IntToSize(1), batch, hidden_size};
  std::vector<std::vector<size_t>> shapes = {shape, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
  AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(0)), split_vd);
  AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(2)), split_vd);
  std::vector<int64_t> size_splits = {SizeToLong(t_size - 1), SizeToLong(1)};
  AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
  return split_vd;
}

AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto ori_shape = AnfAlgo::GetOutputInferShape(node, 0);
  std::vector<std::vector<size_t>> shape_tmp;
  if (ori_shape.size() == 3) {
    shape_tmp = {ori_shape};
  } else {
    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
  }
  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
  // reshape
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), node};
  auto reshape = graph->NewCNode(reshape_input);
  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
  return reshape;
}

AnfNodePtr CreateHConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
  auto reshape = CreateHReshape(graph, node1);

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, ori_outputs[0]};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);

  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(0)), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // SplitV
  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
  auto split_vd = graph->NewCNode(splitvd_input);
  MS_EXCEPTION_IF_NULL(split_vd);
  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
  size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[0];
  size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[1];
  size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[2] / 3;
  std::vector<size_t> shape = {t_size, batch, 2 * hidden_size};
  std::vector<size_t> shape2 = {t_size, batch, hidden_size};
  std::vector<std::vector<size_t>> shapes = {shape, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
  AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(2)), split_vd);
  AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(2)), split_vd);
  std::vector<int64_t> size_splits = {2 * SizeToLong(hidden_size), SizeToLong(hidden_size)};
  AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
  return split_vd;
}

AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  // node1: dgate_h_split
  // node2: dnt_x
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  std::vector<AnfNodePtr> ori_outputs;
  CreateMultipleOutputsOfAnfNode(graph, node1, 2, &ori_outputs);

  // ConcatD
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           ori_outputs[0], node2};
  auto concat_op = graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_op);
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[1],
                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat_op);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
  AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(2)), concat_op);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
  return concat_op;
}

AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  // node1 : input node
  // node2 : orign_input x
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // BroadcastTo
  std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node1};
  auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
  MS_EXCEPTION_IF_NULL(broadcast_to_d);
  size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[0];
  size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[0];
  size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[1];
  std::vector<size_t> shape = {t_size, batch, gate_size};
  auto type = {AnfAlgo::GetOutputInferDataType(node1, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());

  std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(batch), SizeToLong(gate_size)};
  AnfAlgo::SetNodeAttr("shape", MakeValue(attr_shape), broadcast_to_d);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d);
  return broadcast_to_d;
}

AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // BatchMatMul
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node1, node2};
  auto batch_matmul = graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(batch_matmul);
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[2],
                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  return batch_matmul;
}

AnfNodePtr CreateDwhBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  // BatchMatMul
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node1, node2};
  auto batch_matmul = graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(batch_matmul);
  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
                               AnfAlgo::GetOutputInferShape(node2, 0)[1]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  return batch_matmul;
}

AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  // ReduceSumD for dw_x and dw_h
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);
  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node2, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}

AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node2);
  // ReduceSumD for db_x and db_h
  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                              node};
  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
  MS_EXCEPTION_IF_NULL(reduce_sumd);

  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  std::vector<size_t> shape = {3 * AnfAlgo::GetOutputInferShape(node2, 0)[1]};
  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, reduce_sumd.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sumd);
  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
  return reduce_sumd;
}
} // namespace

const BaseRef DynamicGRUV2GradFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicGRUV2Grad, Xs});
}

const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto dynamic_gru_v2_grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
  if (dynamic_gru_v2_grad_cnode->size() < kDynamicGRUV2GradInputNum + 1) {
    MS_LOG(INFO) << "The node " << dynamic_gru_v2_grad_cnode->DebugString() << " has less than "
                 << kDynamicGRUV2GradInputNum << " inputs";
    return nullptr;
  }

  // input_list of dynamic_gru_v2_grad
  const auto &ori_inputs = dynamic_gru_v2_grad_cnode->inputs();
  // add gru_v2_gru_hidden
  auto gru_v2_gru_hidden = CreateGRUV2HiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode);
  std::vector<AnfNodePtr> gru_hidden_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, gru_v2_gru_hidden, kGRUV2HiddenGradOutputNum, &gru_hidden_outputs);
  size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[1], 0)[0];
  AnfNodePtr dwh_batch_matmul = nullptr;
  if (step_num != 1) {
    // split h
    auto h_split = CreateHSplitVDNode(func_graph, ori_inputs[6]);
    // concat(h, h_split)
    auto h_concat = CreateHConcatDNode(func_graph, ori_inputs[5], h_split);
    // batchmatmul(h_concat.T, dgate_h)
    dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, h_concat, gru_hidden_outputs[1]);
  } else {
    auto reshape = CreateHReshape(func_graph, ori_inputs[5]);
    // batchmatmul(init_h.T, dgate_h)
    dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, reshape, gru_hidden_outputs[1]);
  }
  // split dgate_h
  auto dgate_h_split = CreateDgateHSplitVDNode(func_graph, gru_hidden_outputs[1]);
  // concat(dgate_h_split[0], dnt_x) to dgate_x
  auto dgate_x_concat = CreateDgateXConcatDNode(func_graph, dgate_h_split, gru_hidden_outputs[2]);
  // broadcast weight_input [input_size, 3 * hidden_size] to [t_size, input_size, 3 * hidden_size]
  auto w_input_broadcast = CreateWBroadcastToDNode(func_graph, ori_inputs[2], ori_inputs[1]);
  // batchmatmul(x.T, dgate_x_concat)
  auto dwx_batch_matmul = CreateDhxBatchMatMul(func_graph, ori_inputs[1], dgate_x_concat);
  // batchmatmul(dgate_x_concat, w_input_broadcast.T)
  auto dxt_batch_matmul = CreateDwhBatchMatMul(func_graph, dgate_x_concat, w_input_broadcast);
  // reducesum dw_x and dw_h
  auto dwx_reduce_sum = CreateDwReduceSumDNode(func_graph, dwx_batch_matmul, ori_inputs[2]);
  auto dwh_reduce_sum = CreateDwReduceSumDNode(func_graph, dwh_batch_matmul, ori_inputs[3]);
  // reducesum db_x and db_h
  auto dbx_reduce_sum = CreateDbReduceSumDNode(func_graph, dgate_x_concat, ori_inputs[5]);
  auto dbh_reduce_sum = CreateDbReduceSumDNode(func_graph, gru_hidden_outputs[1], ori_inputs[5]);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
                                               dwx_reduce_sum,
                                               dwh_reduce_sum,
                                               dbx_reduce_sum,
                                               dbh_reduce_sum,
                                               dxt_batch_matmul,
                                               gru_hidden_outputs[0]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}
} // namespace opt
} // namespace mindspore
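Read end to end, Process() swaps a single DynamicGRUV2Grad node for a GRUV2HiddenGrad node plus SplitV, Concat, BroadcastTo, BatchMatMul and ReduceSum nodes. The following short Python sketch summarizes, at the shape level, what the pass wires up; it is for orientation only, and the function name and dict layout are assumptions of this note, not part of the patch.

def gru_v2_grad_fission_shapes(t, batch, input_size, hidden_size):
    """Mirror the infer-shapes the fission pass assigns to the nodes it creates (illustrative only)."""
    gate = 3 * hidden_size
    return {
        # GRUV2HiddenGrad outputs: dh_prev (shape copied from output 5 of the original op),
        # plus dgate_h and dnt_x derived from the shape of input h.
        "dgate_h": (t, batch, gate),
        "dnt_x": (t, batch, hidden_size),
        # SplitVD on h along dim 0 (only taken when t != 1).
        "h_split": [(t - 1, batch, hidden_size), (1, batch, hidden_size)],
        # SplitVD on dgate_h along dim 2, then ConcatD with dnt_x -> dgate_x.
        "dgate_h_split": [(t, batch, 2 * hidden_size), (t, batch, hidden_size)],
        "dgate_x": (t, batch, gate),
        # BroadcastTo lifts weight_input so BatchMatMul sees a batched operand.
        "w_input_broadcast": (t, input_size, gate),
        # BatchMatMul followed by ReduceSum(axis=0) yields the weight gradients.
        "dw_input": (input_size, gate),
        "dw_hidden": (hidden_size, gate),
        # ReduceSum(axis=(0, 1)) over dgate_x / dgate_h yields the bias gradients.
        "db_input": (gate,),
        "db_hidden": (gate,),
        # BatchMatMul(dgate_x, w_input_broadcast^T) yields the input gradient dx_t.
        "dx_t": (t, batch, input_size),
    }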
@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_

#include <vector>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class DynamicGRUV2GradFission : public PatternProcessPass {
 public:
  explicit DynamicGRUV2GradFission(bool multigraph = true)
      : PatternProcessPass("dynamic_gru_grad_v2_fission", multigraph) {}
  ~DynamicGRUV2GradFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
@@ -1157,7 +1157,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
         reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True.
 
     Inputs:
-        - **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`.
+        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
           The data type must be float16 or float32.
         - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
           The data type must be float16 or float32.
@@ -1168,17 +1168,17 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
           if num_proj == 0 `(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **init_h** (Tensor) - Hidden state of initial time.
-          Tensor of shape :math:`(batch_size, hidden_size)`, or None.
+          Tensor of shape :math:`(batch_size, hidden_size)`.
           The data type must be float16 or float32.
-        - **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
-        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`.
-        - **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
+        - **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
-        - **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
-        - **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
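The input contract above is easier to read with concrete numbers. The small NumPy sketch below builds dummy arrays with the documented shapes, reusing the toy sizes from the test case at the end of this diff (num_step=2, batch_size=8, input_size=64, hidden_size=16). It is illustrative only and does not invoke the operator; the weight_hidden shape follows the forward op's test tensors rather than the excerpt above.

import numpy as np

num_step, batch_size, input_size, hidden_size = 2, 8, 64, 16
x = np.random.rand(num_step, batch_size, input_size).astype(np.float16)
weight_input = np.random.rand(input_size, 3 * hidden_size).astype(np.float16)
weight_hidden = np.random.rand(hidden_size, 3 * hidden_size).astype(np.float16)
y = np.random.rand(num_step, batch_size, hidden_size).astype(np.float16)    # num_proj == 0
init_h = np.random.rand(batch_size, hidden_size).astype(np.float16)
h = np.random.rand(num_step, batch_size, hidden_size).astype(np.float16)
dy = np.random.rand(*y.shape).astype(np.float16)        # same shape and dtype as y
dh = np.random.rand(*init_h.shape).astype(np.float16)   # same shape and dtype as init_h
update, reset, new, hidden_new = (
    np.random.rand(num_step, batch_size, hidden_size).astype(np.float16) for _ in range(4))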
@@ -492,7 +492,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
           Only `None` is currently supported.
         - **init_h** (Tensor) - Hidden state of initial time.
-          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None.
+          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
           The data type must be float16 or float32.
 
     Outputs:
@@ -511,10 +511,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
           Has the same data type with input `bais_type`.
 
-        - If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32.
+        - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
         - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`.
         - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
-        - Otherwise, `bias_type` is the date type of `init_h`.
 
     Examples:
         >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
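The revised bias_type rules above reduce to a two-step fallback: bias_input first, then bias_hidden, otherwise float32. A minimal sketch of that selection, using a hypothetical helper name; it mirrors the documented rules and the infer_dtype change further down and is not an actual MindSpore function.

from mindspore import dtype as mstype

def _select_bias_type(binput_dtype=None, bhidden_dtype=None):
    # bias_input wins if present, then bias_hidden, otherwise default to float32
    if binput_dtype is not None:
        return binput_dtype
    if bhidden_dtype is not None:
        return bhidden_dtype
    return mstype.float32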
@@ -553,8 +552,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
         self.add_prim_attr("io_format", "ND")
 
-    def infer_shape(self, x_shape, winput_shape, whidden_shape,
-                    binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
+    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
         validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
         validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
@@ -564,7 +562,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         if winput_shape[-1] % 3 != 0:
             raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.")
 
-        self.placeholder_index = [3, 4, 5, 6]
+        self.placeholder_index = [3, 4, 5]
         if binput_shape is not None:
             validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
             validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
@@ -574,14 +572,12 @@ class DynamicGRUV2(PrimitiveWithInfer):
             validator.check("bias_hidden_shape", bhidden_shape,
                             "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
             self.placeholder_index.remove(4)
-        if h_shape is not None:
-            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
-            validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
-            validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
-            self.placeholder_index.remove(6)
         if seq_shape is not None:
             raise ValueError(f"For {self.name}, seq_shape should be None.")
 
+        validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
+        validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
+        validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
         validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
                         whidden_shape[-1], Rel.EQ, self.name)
         validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
@@ -590,15 +586,15 @@ class DynamicGRUV2(PrimitiveWithInfer):
             y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
         else:
             y_shape = (num_step, batch_size, hidden_size)
-        outh_shape = (num_step, batch_size, hidden_size)
+        out_shape = (num_step, batch_size, hidden_size)
         self.add_prim_attr("placeholder_index", self.placeholder_index)
-        return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
+        return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape
 
-    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype,
-                    binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None):
+    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
         validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
+        validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
         b_dtype = mstype.float32
         if binput_dtype is not None:
             validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
@@ -608,10 +604,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
             validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
                                                (mstype.float16, mstype.float32), self.name)
             b_dtype = bhidden_dtype
-        elif h_dtype is not None:
-            validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
-                                               (mstype.float16, mstype.float32), self.name)
-            b_dtype = h_dtype
 
         return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
 
@@ -2532,7 +2532,11 @@ test_case_other_ops = [
                         Tensor(np.random.rand(48).astype(np.float16)),
                         Tensor(np.random.rand(48).astype(np.float16)),
                         Tensor(np.random.rand(8, 16).astype(np.float16))],
-        'skip': ['backward']}),
+        'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
+                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
+                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
+                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
+                       Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
 ]
 
 test_case_quant_ops = [