forked from mindspore-Ecosystem/mindspore
Adapt DynamicGRUV2 forward for the new Ascend backend.
parent 5f7a9bd0b8
commit d471ac491e
@@ -147,6 +147,39 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
  return true;
}

void GenNoneInputDescJson(const std::shared_ptr<OpIOInfo> &input_ptr, size_t input_i,
                          std::vector<nlohmann::json> *input_list) {
  nlohmann::json input_desc_json;
  auto in_name = input_ptr->name();
  input_desc_json[kJName] = in_name + std::to_string(input_i);
  input_desc_json[kJValid] = false;
  input_list->emplace_back(input_desc_json);
}

void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                                 bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                                 const string &op_input_name, size_t input_i,
                                                 std::vector<nlohmann::json> *input_list) {
  auto dtype = GetDeviceInputType(anf_node, real_input_index);
  auto format = GetDeviceInputFormat(anf_node, real_input_index);
  auto shape = GetDeviceInputShape(anf_node, real_input_index);
  auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
  if (ori_shape.empty()) {
    ori_shape.emplace_back(1);
  }
  nlohmann::json input_desc_json;
  input_desc_json[kJDtype] = dtype;
  input_desc_json[kJName] = op_input_name + std::to_string(input_i);
  input_desc_json[kJOriShape] = ori_shape;
  input_desc_json[kJOriFormat] = kOpFormat_NCHW;
  input_desc_json[kJShape] = shape;
  input_desc_json[kJFormat] = format;
  input_desc_json[kJValid] = value;
  input_desc_json[kJParamType] = input_ptr->param_type();
  input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
  input_list->emplace_back(input_desc_json);
}

bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                            bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                            const string &op_input_name, size_t input_i,
@@ -156,32 +189,19 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") {
    nlohmann::json input_desc_json;
    auto in_name = input_ptr->name();
    input_desc_json[kJName] = in_name + std::to_string(input_i);
    input_desc_json[kJValid] = false;
    input_list->emplace_back(input_desc_json);
    GenNoneInputDescJson(input_ptr, input_i, input_list);
  } else if (op_name == kDynamicGRUV2OpName) {
    auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, "placeholder_index");
    auto item = find(none_index.begin(), none_index.end(), input_ptr->index());
    if (item != none_index.end()) {
      GenNoneInputDescJson(input_ptr, input_i, input_list);
    } else {
      GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
    }
  } else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
    TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
  } else {
    auto dtype = GetDeviceInputType(anf_node, real_input_index);
    auto format = GetDeviceInputFormat(anf_node, real_input_index);
    auto shape = GetDeviceInputShape(anf_node, real_input_index);
    auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
    if (ori_shape.empty()) {
      ori_shape.emplace_back(1);
    }
    nlohmann::json input_desc_json;
    input_desc_json[kJDtype] = dtype;
    input_desc_json[kJName] = op_input_name + std::to_string(input_i);
    input_desc_json[kJOriShape] = ori_shape;
    input_desc_json[kJOriFormat] = kOpFormat_NCHW;
    input_desc_json[kJShape] = shape;
    input_desc_json[kJFormat] = format;
    input_desc_json[kJValid] = value;
    input_desc_json[kJParamType] = input_ptr->param_type();
    input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
    input_list->emplace_back(input_desc_json);
    GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
  }
  return true;
}
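For orientation, the two shapes of input descriptor produced by the helpers above look roughly like the sketch below. The key names ("name", "valid", "dtype", ...) are assumptions about what the kJ* constants expand to, and the concrete values are illustrative only, not taken from this diff.

# Hypothetical illustration only: key strings and values are assumed, not read from the kJ* definitions.
none_input_desc = {"name": "seq_length0", "valid": False}   # placeholder input: no tensor bound
valid_input_desc = {
    "name": "x0",                        # op_input_name + input_i
    "dtype": "float16",
    "ori_shape": [2, 8, 64],
    "ori_format": "NCHW",
    "shape": [2, 8, 64],                 # device shape; may differ after format transformation
    "format": "ND",
    "valid": True,
    "param_type": "required",
    "range": [[2, 2], [8, 8], [64, 64]],
}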
@@ -111,6 +111,9 @@ class TbeKernelJsonCreator {
  void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
                     const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
                     std::vector<nlohmann::json> *output_list);
  void GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
                             const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
                             std::vector<nlohmann::json> *input_list);
  std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
  std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
  std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
@@ -63,6 +63,7 @@
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@@ -110,6 +111,7 @@
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@@ -222,6 +224,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
  data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>());
  data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  data_layout_pm->AddPass(std::make_shared<InsertTransposeForDynamicGRUV2>());
  data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
  data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
  data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
@@ -278,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
  AddAscendIRFusionRulesPass(ir_fusion_pm.get());
  AddAscendIRFusionPass(ir_fusion_pm.get());
@@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"

namespace mindspore {
namespace opt {
const BaseRef InsertPlaceholderForDynamicGRUV2::DefinePattern() const {
  std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  return VectorRef({V, Xs});
}

const AnfNodePtr InsertPlaceholderForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                           const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  if (op_name != kDynamicGRUV2OpName) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  if (input_num == 0) {
    return nullptr;
  }

  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
  size_t real_input_index = 0;
  for (size_t in_idx = 0; in_idx < input_num + none_index.size(); in_idx++) {
    auto item = find(none_index.begin(), none_index.end(), in_idx);
    if (item != none_index.end()) {
      auto value = std::make_shared<None>();
      auto value_node = NewValueNode(value);
      value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
      auto new_node = kernel_graph->NewValueNode(value_node);
      kernel_graph->AddValueNodeToGraph(new_node);
      new_inputs.push_back(new_node);
    } else {
      auto input_node = AnfAlgo::GetInputNode(cnode, real_input_index);
      new_inputs.push_back(input_node);
      real_input_index++;
    }
  }

  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
}  // namespace opt
}  // namespace mindspore
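To make the interleaving loop in Process() easier to follow, here is a minimal stand-alone sketch of the same bookkeeping in plain Python (not the MindSpore API); the helper name is made up for illustration.

def interleave_inputs(real_inputs, placeholder_index, total_num):
    """Weave None placeholders into the real inputs at the recorded positions."""
    new_inputs = []
    real_idx = 0
    for pos in range(total_num):
        if pos in placeholder_index:
            new_inputs.append(None)  # stands in for the None value node created above
        else:
            new_inputs.append(real_inputs[real_idx])
            real_idx += 1
    return new_inputs

# DynamicGRUV2 signature: x, weight_input, weight_hidden, bias_input, bias_hidden, seq_length, init_h.
# With both biases and init_h supplied and seq_length omitted, placeholder_index is [5]:
print(interleave_inputs(["x", "wi", "wh", "bi", "bh", "h0"], [5], 7))
# ['x', 'wi', 'wh', 'bi', 'bh', None, 'h0']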
@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_

#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
class InsertPlaceholderForDynamicGRUV2 : public PatternProcessPass {
 public:
  explicit InsertPlaceholderForDynamicGRUV2(bool multigraph = true)
      : PatternProcessPass("add_placeholder_for_dynamic_gru", multigraph) {}
  ~InsertPlaceholderForDynamicGRUV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_

#include <string>
#include <utility>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
class InsertTransposeForDynamicGRUV2 : public PatternProcessPass {
 public:
  explicit InsertTransposeForDynamicGRUV2(bool multigraph = true)
      : PatternProcessPass("insert_transpose_for_dynamic_gru_v2_op", multigraph) {}
  ~InsertTransposeForDynamicGRUV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
@@ -0,0 +1,94 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace opt {
const BaseRef InsertTransposeForDynamicGRUV2::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr X1 = std::make_shared<Var>();
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(X);
  MS_EXCEPTION_IF_NULL(X1);
  MS_EXCEPTION_IF_NULL(Xs);
  return VectorRef(
    {prim::kPrimDynamicGRUV2, X1, VectorRef({prim::KPrimTransData, VectorRef({prim::kPrimReshape, X})}), Xs});
}

CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);

  for (size_t index = 0; index < cnode->inputs().size(); index++) {
    if (index == 1 || index == 2) {
      AnfNodePtr new_node = nullptr;
      AnfNodePtr new_transdata_node = nullptr;
      AnfNodePtr new_transpose_node = nullptr;
      AnfNodePtr transdata_node = AnfAlgo::GetInputNode(cnode, index);
      AnfNodePtr reshape_node = AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0);
      auto input_format = AnfAlgo::GetInputFormat(transdata_node, 0);
      auto output_format = AnfAlgo::GetOutputFormat(transdata_node, 0);
      auto padding_axis = AnfAlgo::GetOutputReshapeType(transdata_node, 0);
      KernelSelectPtr kernel_select = std::make_shared<KernelSelect>();
      // trans default to hwcn
      new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0),
                                          kernel_select, false, prim::kPrimTranspose->name());
      AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node);
      AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
      RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
      // trans hwcn to output_format
      new_transdata_node =
        NewTransOpNode(func_graph, new_transpose_node, kernel_select, false, prim::KPrimTransData->name());
      RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
      new_transdata_node->set_abstract(transdata_node->abstract());
      new_node = new_transdata_node;

      FuncGraphManagerPtr manager = func_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      manager->AddFuncGraph(func_graph);
      if (!manager->Replace(transdata_node, new_node)) {
        MS_LOG(EXCEPTION) << "For DynamicGRUV2, manager replace node failed";
      }
    }
  }
  return cnode;
}

const AnfNodePtr InsertTransposeForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  CNodePtr new_node = nullptr;
  if (op_name == kDynamicGRUV2OpName) {
    new_node = Insert(func_graph, cnode);
  }
  return new_node;
}
}  // namespace opt
}  // namespace mindspore
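The perm attribute {2, 3, 1, 0} set above reorders a default-format (NCHW-ordered) tensor into HWCN, which matches the "trans default to hwcn" comment. A quick NumPy check of the axis permutation; the concrete weight shape here is only an illustration, not taken from this diff.

import numpy as np

# (N, C, H, W) -> axes (2, 3, 1, 0) -> (H, W, C, N), i.e. HWCN.
w = np.zeros((48, 64, 1, 1), dtype=np.float16)  # illustrative 4-D weight
print(np.transpose(w, (2, 3, 1, 0)).shape)      # (1, 1, 64, 48)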
@@ -59,7 +59,7 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
  auto name = AnfAlgo::GetCNodeName(cnode);
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) {
    if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
      continue;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
@@ -133,6 +133,13 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
    if (op_name == kDynamicRNNOpName && i == 3) {
      continue;
    }
    if (op_name == kDynamicGRUV2OpName) {
      auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node_ptr, "placeholder_index");
      auto item = find(none_index.begin(), none_index.end(), i);
      if (item != none_index.end()) {
        continue;
      }
    }
    auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
    AddressPtr input = std::make_shared<Address>();
@@ -227,8 +227,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
constexpr auto kDynamicRNNOpName = "DynamicRNN";
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
constexpr auto kDynamicGRUOpName = "DynamicGRU";
constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad";
constexpr auto kDynamicGRUV2OpName = "DynamicGRUV2";
constexpr auto kGRUV2HiddenGradOpName = "GRUV2HiddenGrad";
constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl";
constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad";
constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam";
@@ -239,6 +239,7 @@ constexpr auto kLARSUpdateName = "LARSUpdate";
constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";

// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -34,8 +34,8 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "weight_input", False, "required", "all") \
    .input(2, "weight_hidden", False, "required", "all") \
    .input(1, "weight_input", False, "required", "all", reshape_type="CN") \
    .input(2, "weight_hidden", False, "required", "all", reshape_type="CN") \
    .input(3, "bias_input", False, "optional", "all") \
    .input(4, "bias_hidden", False, "optional", "all") \
    .input(5, "seq_length", False, "optional", "all") \
@@ -22,7 +22,7 @@ gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \
    .binfile_name("gru_v2_hidden_grad.so") \
    .compute_cost(10) \
    .kernel_name("gru_v2_hidden_grad") \
    .attr("gate_order", "optional", "str", "all", "zrh") \
    .attr("gate_order", "optional", "str", "all", "rzh") \
    .partial_flag(True) \
    .input(0, "weight_input", False, "required", "all") \
    .input(1, "init_h", False, "required", "all") \
@@ -1210,7 +1210,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
                 num_proj=0,
                 time_major=True,
                 bias_type="double_bias",
                 gate_order="zrh",
                 gate_order="rzh",
                 reset_after=True):
        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
@@ -1266,12 +1266,13 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
                    dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
        valid_types = (mstype.float16, mstype.float32)
        args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
                "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
                "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
        args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
                "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
                "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
        if seq_dtype is not None:
            validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
@@ -549,41 +549,49 @@ class DynamicGRUV2(PrimitiveWithInfer):
        self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
    def infer_shape(self, x_shape, winput_shape, whidden_shape,
                    binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
        validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
        validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
        validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
        if binput_shape is not None:
            validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
        if bhidden_shape is not None:
            validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
        if h_shape is not None:
            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")

        num_step, batch_size, input_size = x_shape
        hidden_size = winput_shape[-1] // 3

        if winput_shape[-1] % 3 != 0:
            raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")

        validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
                        whidden_shape[-1], Rel.EQ, self.name)
        validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name)
        validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
        validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
        self.placeholder_index = [3, 4, 5, 6]
        if binput_shape is not None:
            validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
            validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
            self.placeholder_index.remove(3)
        if bhidden_shape is not None:
            validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
            validator.check("bias_hidden_shape", bhidden_shape,
                            "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
            self.placeholder_index.remove(4)
        if h_shape is not None:
            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
            validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
            validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
            self.placeholder_index.remove(6)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")

        validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
                        whidden_shape[-1], Rel.EQ, self.name)
        validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
        validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
        if self.num_proj > 0:
            y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
        else:
            y_shape = (num_step, batch_size, hidden_size)
        outh_shape = (num_step, batch_size, hidden_size)
        self.add_prim_attr("placeholder_index", self.placeholder_index)
        return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape

    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype,
                    binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None):
        validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
        validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
        validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
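The placeholder_index bookkeeping added to infer_shape can be traced with a small worked example, mirroring the list/remove logic above for the case exercised by the new test below: both biases and init_h provided, seq_length left as None.

placeholder_index = [3, 4, 5, 6]          # bias_input, bias_hidden, seq_length, init_h
binput_shape, bhidden_shape, h_shape = (48,), (48,), (8, 16)

if binput_shape is not None:
    placeholder_index.remove(3)
if bhidden_shape is not None:
    placeholder_index.remove(4)
if h_shape is not None:
    placeholder_index.remove(6)

print(placeholder_index)  # [5]: only the seq_length slot gets a None placeholder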
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class DynamicGRUV2(nn.Cell):
    def __init__(self):
        super(DynamicGRUV2, self).__init__()
        self.dynamic_gru = P.DynamicGRUV2()

    def construct(self, x, weight_i, weight_h, bias_i, bias_h, init_h):
        return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h)


def test_dynamic_gru_v2():
    x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
    weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
    weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
    bias_i = Tensor(np.random.rand(48).astype(np.float16))
    bias_h = Tensor(np.random.rand(48).astype(np.float16))
    init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
    gru_net = DynamicGRUV2()
    output = gru_net(x, weight_i, weight_h, bias_i, bias_h, init_h)
    print(output)
@@ -2532,11 +2532,7 @@ test_case_other_ops = [
                    Tensor(np.random.rand(48).astype(np.float16)),
                    Tensor(np.random.rand(48).astype(np.float16)),
                    Tensor(np.random.rand(8, 16).astype(np.float16))],
        'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
        'skip': ['backward']}),
]

test_case_quant_ops = [