fix dropout domask format select
This commit is contained in:
parent
6733c6a529
commit
d84c34c795
|
@ -241,8 +241,10 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
|
||||
if (cnode != nullptr) {
|
||||
primal_attrs = cnode->primal_attrs();
|
||||
cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue(cnode->UniqueId()));
|
||||
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
|
||||
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
|
||||
primal_attrs[kPrimalAttrForwardUniqueId] = MakeValue(cnode->UniqueId());
|
||||
}
|
||||
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
|
||||
if (expanded_fg == nullptr) {
|
||||
|
|
|
@ -342,8 +342,13 @@ CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &orig
|
|||
++recompute_id;
|
||||
recomputed_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id));
|
||||
origin_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id));
|
||||
if (IsPrimitiveCNode(origin_node, prim::kPrimDropout) && origin_node->HasPrimalAttr(kAttrFusion)) {
|
||||
recomputed_node->AddPrimalAttr(kAttrFusion, origin_node->GetPrimalAttr(kAttrFusion));
|
||||
static const PrimitiveSet dropout_prims = {prim::kPrimDropout, prim::kPrimDropoutDoMask, prim::kPrimDropoutDoMaskV3};
|
||||
static const std::vector<std::string> need_primal_attr = {kAttrFusion, kPrimalAttrUniqueId,
|
||||
kPrimalAttrForwardUniqueId};
|
||||
if (IsOneOfPrimitiveCNode(origin_node, dropout_prims)) {
|
||||
for (auto &primal_attr : need_primal_attr) {
|
||||
recomputed_node->AddPrimalAttr(primal_attr, origin_node->GetPrimalAttr(primal_attr));
|
||||
}
|
||||
}
|
||||
return recomputed_node;
|
||||
}
|
||||
|
|
|
@ -921,6 +921,8 @@ constexpr auto kCustomTypeHybrid = "hybrid";
|
|||
|
||||
// primal attr key name
|
||||
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
|
||||
constexpr auto kPrimalAttrUniqueId = "unique_id";
|
||||
constexpr auto kPrimalAttrForwardUniqueId = "forward_unique_id";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,126 +24,45 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// used to store the cross subgraph domask format
|
||||
std::map<std::string, std::string> rectify_do_mask_map;
|
||||
|
||||
void RectifyKernelInfoInPynativeProcess(const CNodePtr &cnode) {
|
||||
auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
if (cnode_name != prim::kPrimDropOutDoMask->name() && cnode_name != prim::kPrimDropOutDoMaskV3->name()) {
|
||||
return;
|
||||
}
|
||||
auto do_mask_input_format = AnfAlgo::GetInputFormat(cnode, 0);
|
||||
if (do_mask_input_format != kOpFormat_DEFAULT) {
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(cnode));
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
|
||||
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
|
||||
}
|
||||
}
|
||||
|
||||
bool IsDoMask(const BaseRef &ref) {
|
||||
if (utils::isa<AnfNodePtr>(ref)) {
|
||||
auto node = utils::cast<AnfNodePtr>(ref);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsPrimitive(node, prim::kPrimDropOutDoMask) || IsPrimitive(node, prim::kPrimDropOutDoMaskV3)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr X = std::make_shared<CondVar>(IsDoMask);
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return RectifyKernelInfoInPynativeProcess(node);
|
||||
}
|
||||
auto node_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
if (node_name != prim::kDropoutGenMask || node_name != prim::kStatelessDropOutGenMask) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<CNodePtr> do_mask_node_list;
|
||||
auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode);
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_output_nodes);
|
||||
for (const auto &output_node : *gen_mask_output_nodes) {
|
||||
if (common::AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropOutDoMask->name()) {
|
||||
MS_EXCEPTION_IF_NULL(output_node.first);
|
||||
auto output_cnode = output_node.first->cast<CNodePtr>();
|
||||
do_mask_node_list.push_back(output_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
RectifyKernelInfo(do_mask_node_list, graph);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
||||
const FuncGraphPtr &graph) const {
|
||||
std::map<std::string, size_t> format_counter;
|
||||
std::string special_format;
|
||||
std::string convert_format;
|
||||
for (const auto &do_mask : do_mask_node_list) {
|
||||
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
|
||||
if (special_format.empty() && IsOneOfHWSpecialFormat(do_mask_data_format)) {
|
||||
special_format = do_mask_data_format;
|
||||
}
|
||||
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
|
||||
format_counter[do_mask_data_format] = 1;
|
||||
} else {
|
||||
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
|
||||
}
|
||||
}
|
||||
if (format_counter.size() == 1) {
|
||||
return;
|
||||
}
|
||||
if (convert_format.empty()) {
|
||||
convert_format = GetConvertFormat(format_counter);
|
||||
}
|
||||
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph);
|
||||
}
|
||||
|
||||
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
|
||||
constexpr size_t kFormatCount = 2;
|
||||
std::string convert_format = kOpFormat_DEFAULT;
|
||||
size_t counter = 0;
|
||||
if (format_counter.size() > kFormatCount) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
if (format_counter.size() == kFormatCount && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
for (const auto &iter : format_counter) {
|
||||
if (counter < iter.second) {
|
||||
convert_format = iter.first;
|
||||
counter = iter.second;
|
||||
} else if (counter == iter.second && IsOneOfHWSpecialFormat(iter.first)) {
|
||||
convert_format = iter.first;
|
||||
}
|
||||
}
|
||||
return convert_format;
|
||||
}
|
||||
|
||||
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
||||
const std::string &format,
|
||||
const FuncGraphPtr &graph) const {
|
||||
for (const auto &do_mask : do_mask_node_list) {
|
||||
if (AnfAlgo::GetInputFormat(do_mask, 0) != format) {
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputFormat(format, 0);
|
||||
builder->SetOutputFormat(format, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
|
||||
ReSelecChildNodeKernelInfo(do_mask, graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropOutDoMask->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
|
||||
if (do_mask_input_format != kOpFormat_DEFAULT) {
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
|
||||
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
|
||||
void RectifyDoMaskKernelInfo::ReSelectChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto output_node_list = GetRealNodeUsedList(graph, cnode);
|
||||
MS_EXCEPTION_IF_NULL(output_node_list);
|
||||
|
@ -157,15 +76,58 @@ void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode,
|
|||
MS_EXCEPTION_IF_NULL(new_build_info);
|
||||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||
if ((*new_build_info) != (*ori_build_info)) {
|
||||
ReSelecChildNodeKernelInfo(out_node, graph);
|
||||
ReSelectChildNodeKernelInfo(out_node, graph);
|
||||
}
|
||||
} else if (common::AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() ||
|
||||
common::AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) {
|
||||
ReSelecChildNodeKernelInfo(out_node, graph);
|
||||
ReSelectChildNodeKernelInfo(out_node, graph);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto do_mask_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(do_mask_node);
|
||||
if (!do_mask_node->HasPrimalAttr(kPrimalAttrUniqueId) && !do_mask_node->HasPrimalAttr(kPrimalAttrForwardUniqueId)) {
|
||||
MS_LOG(WARNING) << "The DoMask cnode has no primal attr: " << do_mask_node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(do_mask_node);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
RectifyKernelInfoInPynativeProcess(do_mask_node);
|
||||
return nullptr;
|
||||
}
|
||||
std::string unique_id;
|
||||
if (do_mask_node->HasPrimalAttr(kPrimalAttrUniqueId)) {
|
||||
unique_id = do_mask_node->GetPrimalAttr(kPrimalAttrUniqueId)->DumpText();
|
||||
} else if (do_mask_node->HasPrimalAttr(kPrimalAttrForwardUniqueId)) {
|
||||
unique_id = do_mask_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)->DumpText();
|
||||
}
|
||||
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask_node, 0);
|
||||
auto iter = rectify_do_mask_map.find(unique_id);
|
||||
if (iter == rectify_do_mask_map.end()) {
|
||||
rectify_do_mask_map.insert({unique_id, do_mask_data_format});
|
||||
} else {
|
||||
auto &format = iter->second;
|
||||
MS_LOG(INFO) << "Node: " << do_mask_node->DebugString() << " need convert origin format: " << do_mask_data_format
|
||||
<< " to new format: " << format;
|
||||
if (format != do_mask_data_format) {
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
|
||||
AnfAlgo::GetSelectKernelBuildInfo(do_mask_node));
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputFormat(format, 0);
|
||||
builder->SetOutputFormat(format, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask_node.get());
|
||||
ReSelectChildNodeKernelInfo(do_mask_node, do_mask_node->func_graph());
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "plugin/device/ascend/optimizer/ascend_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
||||
|
@ -31,16 +32,12 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
|||
: PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
|
||||
kernel_selecter(std::make_shared<KernelSelect>()) {}
|
||||
~RectifyDoMaskKernelInfo() override = default;
|
||||
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const FuncGraphPtr &graph) const;
|
||||
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
|
||||
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
|
||||
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format,
|
||||
const FuncGraphPtr &graph) const;
|
||||
void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
|
||||
void ReSelectChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
|
||||
KernelSelectPtr kernel_selecter;
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -238,8 +238,12 @@ CNodePtr CreateDropoutDoMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr
|
|||
dropout_do_mask->set_abstract(abstract);
|
||||
dropout_do_mask->set_scope(dropout->scope());
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrKeepProb, dropout, dropout_do_mask);
|
||||
if (dropout->HasPrimalAttr(kAttrMicro)) {
|
||||
dropout_do_mask->AddPrimalAttr(kAttrMicro, dropout->GetPrimalAttr(kAttrMicro));
|
||||
|
||||
std::vector<std::string> need_primal_attr = {kAttrMicro, kPrimalAttrUniqueId, kPrimalAttrForwardUniqueId};
|
||||
for (auto &primal_attr : need_primal_attr) {
|
||||
if (dropout->HasPrimalAttr(primal_attr)) {
|
||||
dropout_do_mask->AddPrimalAttr(primal_attr, dropout->GetPrimalAttr(primal_attr));
|
||||
}
|
||||
}
|
||||
return dropout_do_mask;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "include/common/debug/draw.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "plugin/device/ascend/optimizer/mindir/ascend_vm_op_adapter.h"
|
||||
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "plugin/device/ascend/optimizer/format_type/rectify_do_mask_kernel_info.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
|
||||
class TestHWRectifyDoMask : public BackendCommon {
|
||||
public:
|
||||
TestHWRectifyDoMask() : getPyFun_("gtest_input.pre_activate.rectify_do_mask_test", true) {}
|
||||
~TestHWRectifyDoMask() override = default;
|
||||
UT::PyFuncGraphFetcher getPyFun_;
|
||||
|
||||
void SetFormat(const FuncGraphPtr &g) {
|
||||
auto node_list = TopoSort(g->get_return());
|
||||
int cnt = 0;
|
||||
for (auto &node : node_list) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
cnt ++;
|
||||
KernelBuildInfoBuilder builder;
|
||||
if (cnt == 1) {
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
|
||||
cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue("1"));
|
||||
} else {
|
||||
builder.SetInputsFormat({kOpFormat_FRAC_NZ, kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetOutputsFormat({kOpFormat_FRAC_NZ});
|
||||
node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
|
||||
cnode->AddPrimalAttr(kPrimalAttrForwardUniqueId, MakeValue("1"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckNQ(const FuncGraphPtr &g) {
|
||||
std::string format;
|
||||
auto node_list = TopoSort(g->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
|
||||
if (format.empty()) {
|
||||
format = AnfAlgo::GetInputFormat(node, 0);
|
||||
} else {
|
||||
EXPECT_NE(format, AnfAlgo::GetInputFormat(node, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERT_FALSE(format.empty());
|
||||
}
|
||||
|
||||
void CheckEQ(const FuncGraphPtr &g) {
|
||||
std::string format;
|
||||
auto node_list = TopoSort(g->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDropOutDoMask)) {
|
||||
if (format.empty()) {
|
||||
format = AnfAlgo::GetInputFormat(node, 0);
|
||||
} else {
|
||||
EXPECT_EQ(format, AnfAlgo::GetInputFormat(node, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERT_FALSE(format.empty());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Feature: Test RectifyDoMaskKernelInfo pass
|
||||
/// Description: Test RectifyDoMaskKernelInfo pass
|
||||
/// Expectation: The forward and backward DropOutDoMask should select same format.
|
||||
TEST_F(TestHWRectifyDoMask, test_rectify_dropout_do_mask) {
|
||||
/*
|
||||
* def test_rectify_dropout_do_mask(x):
|
||||
* mask = gen_mask(shape, 0.9)
|
||||
* res_x = do_mask(x, mask, 0.9)
|
||||
* res_y = do_mask(y, mask, 0.9)
|
||||
* return res_x, res_y
|
||||
*
|
||||
*/
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
FuncGraphPtr fg = getPyFun_.CallAndParseRet("test_rectify_dropout_do_mask", "f");
|
||||
SetFormat(fg);
|
||||
CheckNQ(fg);
|
||||
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pass_manager = std::make_shared<opt::PassManager>();
|
||||
pass_manager->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
|
||||
pass_manager->AddPass(std::make_shared<opt::RectifyDoMaskKernelInfo>());
|
||||
graph_optimizer->AddPassManager(pass_manager);
|
||||
auto new_g = graph_optimizer->Optimize(fg);
|
||||
CheckEQ(new_g);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
shape = (2, 4, 2, 2)
|
||||
gen_mask = P.DropoutGenMask(10, 10)
|
||||
do_mask = P.DropoutDoMask()
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fn_dict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fn_dict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fn_dict.get(name)
|
||||
|
||||
|
||||
def test_rectify_dropout_do_mask(tag):
|
||||
"""
|
||||
Feature: Test RectifyDoMaskKernelInfo pass
|
||||
Description: Test RectifyDoMaskKernelInfo pass
|
||||
Expectation: The forward and backward DropOutDoMask should select same format.
|
||||
"""
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def f(x, y):
|
||||
mask = gen_mask(shape, 0.9)
|
||||
res_x = do_mask(x, mask, 0.9)
|
||||
res_y = do_mask(y, mask, 0.9)
|
||||
return res_x, res_y
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue