Adapt DynamicGRUV2 forward for the new Ascend backend.

liuxiao93 2020-11-12 11:03:50 +08:00
parent 5f7a9bd0b8
commit d471ac491e
16 changed files with 390 additions and 54 deletions
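The change revolves around a new "placeholder_index" attribute on DynamicGRUV2. A minimal sketch of the convention, with indices taken from the infer_shape changes further down in this diff (the variable names here are for illustration only):

optional_inputs = {3: "bias_input", 4: "bias_hidden", 5: "seq_length", 6: "init_h"}
# Every optional input the caller leaves as None keeps its index in
# placeholder_index; later passes insert a None value node for it (or skip
# its address at launch time). For example, a call that supplies bias_input,
# bias_hidden and init_h but no seq_length ends up with:
placeholder_index = [5]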

View File

@ -147,6 +147,39 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
return true;
}
void GenNoneInputDescJson(const std::shared_ptr<OpIOInfo> &input_ptr, size_t input_i,
std::vector<nlohmann::json> *input_list) {
nlohmann::json input_desc_json;
auto in_name = input_ptr->name();
input_desc_json[kJName] = in_name + std::to_string(input_i);
input_desc_json[kJValid] = false;
input_list->emplace_back(input_desc_json);
}
void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list) {
auto dtype = GetDeviceInputType(anf_node, real_input_index);
auto format = GetDeviceInputFormat(anf_node, real_input_index);
auto shape = GetDeviceInputShape(anf_node, real_input_index);
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
nlohmann::json input_desc_json;
input_desc_json[kJDtype] = dtype;
input_desc_json[kJName] = op_input_name + std::to_string(input_i);
input_desc_json[kJOriShape] = ori_shape;
input_desc_json[kJOriFormat] = kOpFormat_NCHW;
input_desc_json[kJShape] = shape;
input_desc_json[kJFormat] = format;
input_desc_json[kJValid] = value;
input_desc_json[kJParamType] = input_ptr->param_type();
input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
input_list->emplace_back(input_desc_json);
}
bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
const string &op_input_name, size_t input_i,
@ -156,32 +189,19 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
MS_EXCEPTION_IF_NULL(input_list);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") {
nlohmann::json input_desc_json;
auto in_name = input_ptr->name();
input_desc_json[kJName] = in_name + std::to_string(input_i);
input_desc_json[kJValid] = false;
input_list->emplace_back(input_desc_json);
GenNoneInputDescJson(input_ptr, input_i, input_list);
} else if (op_name == kDynamicGRUV2OpName) {
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, "placeholder_index");
auto item = find(none_index.begin(), none_index.end(), input_ptr->index());
if (item != none_index.end()) {
GenNoneInputDescJson(input_ptr, input_i, input_list);
} else {
GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
}
} else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
} else {
auto dtype = GetDeviceInputType(anf_node, real_input_index);
auto format = GetDeviceInputFormat(anf_node, real_input_index);
auto shape = GetDeviceInputShape(anf_node, real_input_index);
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
nlohmann::json input_desc_json;
input_desc_json[kJDtype] = dtype;
input_desc_json[kJName] = op_input_name + std::to_string(input_i);
input_desc_json[kJOriShape] = ori_shape;
input_desc_json[kJOriFormat] = kOpFormat_NCHW;
input_desc_json[kJShape] = shape;
input_desc_json[kJFormat] = format;
input_desc_json[kJValid] = value;
input_desc_json[kJParamType] = input_ptr->param_type();
input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
input_list->emplace_back(input_desc_json);
GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
}
return true;
}
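For reference, a rough sketch of the per-input JSON the two helpers above emit, written as Python dicts; the key strings mirror the kJ* constants, and the field values are invented placeholders rather than output from a real build:

none_input_desc = {
    "name": "seq_length0",   # input_ptr name + input_i; no other fields
    "valid": False,          # placeholder input: no tensor is supplied
}
valid_input_desc = {
    "dtype": "float16",
    "name": "x0",                 # op_input_name + input_i
    "ori_shape": [2, 8, 64],
    "ori_format": "NCHW",
    "shape": [2, 8, 64],          # device shape chosen by kernel selection
    "format": "DefaultFormat",
    "valid": True,
    "param_type": "required",
    "range": [[2, 2], [8, 8], [64, 64]],
}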

View File

@ -111,6 +111,9 @@ class TbeKernelJsonCreator {
void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
std::vector<nlohmann::json> *output_list);
void GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list);
std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;

View File

@ -63,6 +63,7 @@
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@ -110,6 +111,7 @@
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@ -222,6 +224,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>());
data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
data_layout_pm->AddPass(std::make_shared<InsertTransposeForDynamicGRUV2>());
data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
@ -278,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());

View File

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
const BaseRef InsertPlaceholderForDynamicGRUV2::DefinePattern() const {
std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
const AnfNodePtr InsertPlaceholderForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != kDynamicGRUV2OpName) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
size_t input_num = AnfAlgo::GetInputTensorNum(node);
if (input_num == 0) {
return nullptr;
}
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
size_t real_input_index = 0;
for (size_t in_idx = 0; in_idx < input_num + none_index.size(); in_idx++) {
auto item = find(none_index.begin(), none_index.end(), in_idx);
if (item != none_index.end()) {
auto value = std::make_shared<None>();
auto value_node = NewValueNode(value);
value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
auto new_node = kernel_graph->NewValueNode(value_node);
kernel_graph->AddValueNodeToGraph(new_node);
new_inputs.push_back(new_node);
} else {
auto input_node = AnfAlgo::GetInputNode(cnode, real_input_index);
new_inputs.push_back(input_node);
real_input_index++;
}
}
CNodePtr new_node = nullptr;
if (kernel_graph == nullptr) {
new_node = std::make_shared<CNode>(*cnode);
} else {
new_node = kernel_graph->NewCNode(cnode);
}
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);
return new_node;
}
} // namespace opt
} // namespace mindspore
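The input-reassembly loop above can be restated in a few lines of Python; this is only an illustration of the index bookkeeping (input names are hypothetical), not the pass itself:

def insert_placeholders(real_inputs, none_index, total_len):
    new_inputs, real_i = [], 0
    for in_idx in range(total_len):
        if in_idx in none_index:
            new_inputs.append(None)            # stands in for the None value node
        else:
            new_inputs.append(real_inputs[real_i])
            real_i += 1
    return new_inputs

# DynamicGRUV2 called without bias_hidden and seq_length (placeholder_index = [4, 5]):
# insert_placeholders(["x", "w_i", "w_h", "b_i", "h0"], [4, 5], 5 + 2)
# -> ["x", "w_i", "w_h", "b_i", None, None, "h0"]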

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class InsertPlaceholderForDynamicGRUV2 : public PatternProcessPass {
public:
explicit InsertPlaceholderForDynamicGRUV2(bool multigraph = true)
: PatternProcessPass("add_placeholder_for_dynamic_gru", multigraph) {}
~InsertPlaceholderForDynamicGRUV2() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
#include <string>
#include <utility>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class InsertTransposeForDynamicGRUV2 : public PatternProcessPass {
public:
explicit InsertTransposeForDynamicGRUV2(bool multigraph = true)
: PatternProcessPass("insert_transpose_for_dynamic_gru_v2_op", multigraph) {}
~InsertTransposeForDynamicGRUV2() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_

View File

@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
const BaseRef InsertTransposeForDynamicGRUV2::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr X1 = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(X);
MS_EXCEPTION_IF_NULL(X1);
MS_EXCEPTION_IF_NULL(Xs);
return VectorRef(
{prim::kPrimDynamicGRUV2, X1, VectorRef({prim::KPrimTransData, VectorRef({prim::kPrimReshape, X})}), Xs});
}
CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
for (size_t index = 0; index < cnode->inputs().size(); index++) {
if (index == 1 || index == 2) {
AnfNodePtr new_node = nullptr;
AnfNodePtr new_transdata_node = nullptr;
AnfNodePtr new_transpose_node = nullptr;
AnfNodePtr transdata_node = AnfAlgo::GetInputNode(cnode, index);
AnfNodePtr reshape_node = AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0);
auto input_format = AnfAlgo::GetInputFormat(transdata_node, 0);
auto output_format = AnfAlgo::GetOutputFormat(transdata_node, 0);
auto padding_axis = AnfAlgo::GetOutputReshapeType(transdata_node, 0);
KernelSelectPtr kernel_select = std::make_shared<KernelSelect>();
// trans default to hwcn
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0),
kernel_select, false, prim::kPrimTranspose->name());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node);
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
// trans hwcn to output_format
new_transdata_node =
NewTransOpNode(func_graph, new_transpose_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
new_transdata_node->set_abstract(transdata_node->abstract());
new_node = new_transdata_node;
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
if (!manager->Replace(transdata_node, new_node)) {
MS_LOG(EXCEPTION) << "For DynamicGRUV2, manager replace node failed";
}
}
}
return cnode;
}
const AnfNodePtr InsertTransposeForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
CNodePtr new_node = nullptr;
if (op_name == kDynamicGRUV2OpName) {
new_node = Insert(func_graph, cnode);
}
return new_node;
}
} // namespace opt
} // namespace mindspore
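A small numpy sketch of the permutation this pass inserts; perm = [2, 3, 1, 0] turns a tensor laid out as NCHW into HWCN, which is the layout the re-created TransData expects (the 4-D shape used here is hypothetical):

import numpy as np

w = np.zeros((1, 1, 64, 48), dtype=np.float16)  # N, C, H, W
w_hwcn = np.transpose(w, (2, 3, 1, 0))          # new axis i takes old axis perm[i]
print(w_hwcn.shape)                             # (64, 48, 1, 1) -> H, W, C, N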

View File

@ -59,7 +59,7 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
auto name = AnfAlgo::GetCNodeName(cnode);
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) {
if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
continue;
}
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {

View File

@ -133,6 +133,13 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
if (op_name == kDynamicRNNOpName && i == 3) {
continue;
}
if (op_name == kDynamicGRUV2OpName) {
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node_ptr, "placeholder_index");
auto item = find(none_index.begin(), none_index.end(), i);
if (item != none_index.end()) {
continue;
}
}
auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
AddressPtr input = std::make_shared<Address>();

View File

@ -227,8 +227,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
constexpr auto kDynamicRNNOpName = "DynamicRNN";
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
constexpr auto kDynamicGRUOpName = "DynamicGRU";
constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad";
constexpr auto kDynamicGRUV2OpName = "DynamicGRUV2";
constexpr auto kGRUV2HiddenGradOpName = "GRUV2HiddenGrad";
constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl";
constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad";
constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam";
@ -239,6 +239,7 @@ constexpr auto kLARSUpdateName = "LARSUpdate";
constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";
// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";

View File

@ -34,8 +34,8 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \
.attr("is_training", "optional", "bool", "all", "true") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "weight_input", False, "required", "all") \
.input(2, "weight_hidden", False, "required", "all") \
.input(1, "weight_input", False, "required", "all", reshape_type="CN") \
.input(2, "weight_hidden", False, "required", "all", reshape_type="CN") \
.input(3, "bias_input", False, "optional", "all") \
.input(4, "bias_hidden", False, "optional", "all") \
.input(5, "seq_length", False, "optional", "all") \

View File

@ -22,7 +22,7 @@ gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \
.binfile_name("gru_v2_hidden_grad.so") \
.compute_cost(10) \
.kernel_name("gru_v2_hidden_grad") \
.attr("gate_order", "optional", "str", "all", "zrh") \
.attr("gate_order", "optional", "str", "all", "rzh") \
.partial_flag(True) \
.input(0, "weight_input", False, "required", "all") \
.input(1, "init_h", False, "required", "all") \

View File

@ -1210,7 +1210,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
num_proj=0,
time_major=True,
bias_type="double_bias",
gate_order="zrh",
gate_order="rzh",
reset_after=True):
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
@ -1266,12 +1266,13 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
valid_types = (mstype.float16, mstype.float32)
args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
"dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
"reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
"dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
"new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
if seq_dtype is not None:
validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)

View File

@ -549,41 +549,49 @@ class DynamicGRUV2(PrimitiveWithInfer):
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
def infer_shape(self, x_shape, winput_shape, whidden_shape,
binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
if binput_shape is not None:
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
if bhidden_shape is not None:
validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
if h_shape is not None:
validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")
num_step, batch_size, input_size = x_shape
hidden_size = winput_shape[-1] // 3
if winput_shape[-1] % 3 != 0:
raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.")
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
whidden_shape[-1], Rel.EQ, self.name)
validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
self.placeholder_index = [3, 4, 5, 6]
if binput_shape is not None:
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(3)
if bhidden_shape is not None:
validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
validator.check("bias_hidden_shape", bhidden_shape,
"3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(4)
if h_shape is not None:
validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
self.placeholder_index.remove(6)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
whidden_shape[-1], Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
if self.num_proj > 0:
y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
else:
y_shape = (num_step, batch_size, hidden_size)
outh_shape = (num_step, batch_size, hidden_size)
self.add_prim_attr("placeholder_index", self.placeholder_index)
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype,
binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None):
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class DynamicGRUV2(nn.Cell):
def __init__(self):
super(DynamicGRUV2, self).__init__()
self.dynamic_gru = P.DynamicGRUV2()
def construct(self, x, weight_i, weight_h, bias_i, bias_h, init_h):
return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
def test_dynamic_gru_v2():
x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
bias_i = Tensor(np.random.rand(48).astype(np.float16))
bias_h = Tensor(np.random.rand(48).astype(np.float16))
init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
gru_net = DynamicGRUV2()
output = gru_net(x, weight_i, weight_h, bias_i, bias_h, init_h)
print(output)

View File

@ -2532,11 +2532,7 @@ test_case_other_ops = [
Tensor(np.random.rand(48).astype(np.float16)),
Tensor(np.random.rand(48).astype(np.float16)),
Tensor(np.random.rand(8, 16).astype(np.float16))],
'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
'skip': ['backward']}),
]
test_case_quant_ops = [