fix dropout domask format select

This commit is contained in:
reku1997 2022-11-30 10:22:27 +08:00
parent 6733c6a529
commit d84c34c795
8 changed files with 278 additions and 129 deletions

View File

@ -241,8 +241,10 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
if (cnode != nullptr) {
primal_attrs = cnode->primal_attrs();
cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue(cnode->UniqueId()));
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
primal_attrs[kPrimalAttrForwardUniqueId] = MakeValue(cnode->UniqueId());
}
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
if (expanded_fg == nullptr) {

View File

@ -342,8 +342,13 @@ CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &orig
++recompute_id;
recomputed_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id));
origin_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id));
if (IsPrimitiveCNode(origin_node, prim::kPrimDropout) && origin_node->HasPrimalAttr(kAttrFusion)) {
recomputed_node->AddPrimalAttr(kAttrFusion, origin_node->GetPrimalAttr(kAttrFusion));
static const PrimitiveSet dropout_prims = {prim::kPrimDropout, prim::kPrimDropoutDoMask, prim::kPrimDropoutDoMaskV3};
static const std::vector<std::string> need_primal_attr = {kAttrFusion, kPrimalAttrUniqueId,
kPrimalAttrForwardUniqueId};
if (IsOneOfPrimitiveCNode(origin_node, dropout_prims)) {
for (auto &primal_attr : need_primal_attr) {
recomputed_node->AddPrimalAttr(primal_attr, origin_node->GetPrimalAttr(primal_attr));
}
}
return recomputed_node;
}

View File

@ -921,6 +921,8 @@ constexpr auto kCustomTypeHybrid = "hybrid";
// primal attr key name
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
constexpr auto kPrimalAttrUniqueId = "unique_id";
constexpr auto kPrimalAttrForwardUniqueId = "forward_unique_id";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,126 +24,45 @@
namespace mindspore {
namespace opt {
namespace {
// used to store the cross subgraph domask format
std::map<std::string, std::string> rectify_do_mask_map;
void RectifyKernelInfoInPynativeProcess(const CNodePtr &cnode) {
auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
if (cnode_name != prim::kPrimDropOutDoMask->name() && cnode_name != prim::kPrimDropOutDoMaskV3->name()) {
return;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(cnode, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(cnode));
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
}
}
bool IsDoMask(const BaseRef &ref) {
if (utils::isa<AnfNodePtr>(ref)) {
auto node = utils::cast<AnfNodePtr>(ref);
MS_EXCEPTION_IF_NULL(node);
if (IsPrimitive(node, prim::kPrimDropOutDoMask) || IsPrimitive(node, prim::kPrimDropOutDoMaskV3)) {
return true;
}
}
return false;
}
} // namespace
const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr X = std::make_shared<CondVar>(IsDoMask);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({X, Xs});
}
const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return RectifyKernelInfoInPynativeProcess(node);
}
auto node_name = common::AnfAlgo::GetCNodeName(cnode);
if (node_name != prim::kDropoutGenMask || node_name != prim::kStatelessDropOutGenMask) {
return nullptr;
}
std::vector<CNodePtr> do_mask_node_list;
auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode);
MS_EXCEPTION_IF_NULL(gen_mask_output_nodes);
for (const auto &output_node : *gen_mask_output_nodes) {
if (common::AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropOutDoMask->name()) {
MS_EXCEPTION_IF_NULL(output_node.first);
auto output_cnode = output_node.first->cast<CNodePtr>();
do_mask_node_list.push_back(output_cnode);
}
}
RectifyKernelInfo(do_mask_node_list, graph);
return nullptr;
}
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const FuncGraphPtr &graph) const {
std::map<std::string, size_t> format_counter;
std::string special_format;
std::string convert_format;
for (const auto &do_mask : do_mask_node_list) {
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
if (special_format.empty() && IsOneOfHWSpecialFormat(do_mask_data_format)) {
special_format = do_mask_data_format;
}
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
format_counter[do_mask_data_format] = 1;
} else {
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
}
}
if (format_counter.size() == 1) {
return;
}
if (convert_format.empty()) {
convert_format = GetConvertFormat(format_counter);
}
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph);
}
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
constexpr size_t kFormatCount = 2;
std::string convert_format = kOpFormat_DEFAULT;
size_t counter = 0;
if (format_counter.size() > kFormatCount) {
return kOpFormat_DEFAULT;
}
if (format_counter.size() == kFormatCount && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) {
return kOpFormat_DEFAULT;
}
for (const auto &iter : format_counter) {
if (counter < iter.second) {
convert_format = iter.first;
counter = iter.second;
} else if (counter == iter.second && IsOneOfHWSpecialFormat(iter.first)) {
convert_format = iter.first;
}
}
return convert_format;
}
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const std::string &format,
const FuncGraphPtr &graph) const {
for (const auto &do_mask : do_mask_node_list) {
if (AnfAlgo::GetInputFormat(do_mask, 0) != format) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputFormat(format, 0);
builder->SetOutputFormat(format, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
ReSelecChildNodeKernelInfo(do_mask, graph);
}
}
}
AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
if (common::AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropOutDoMask->name()) {
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
}
void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
void RectifyDoMaskKernelInfo::ReSelectChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(cnode);
auto output_node_list = GetRealNodeUsedList(graph, cnode);
MS_EXCEPTION_IF_NULL(output_node_list);
@ -157,15 +76,58 @@ void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode,
MS_EXCEPTION_IF_NULL(new_build_info);
MS_EXCEPTION_IF_NULL(ori_build_info);
if ((*new_build_info) != (*ori_build_info)) {
ReSelecChildNodeKernelInfo(out_node, graph);
ReSelectChildNodeKernelInfo(out_node, graph);
}
} else if (common::AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() ||
common::AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) {
ReSelecChildNodeKernelInfo(out_node, graph);
ReSelectChildNodeKernelInfo(out_node, graph);
} else {
MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed";
}
}
}
const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
auto do_mask_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(do_mask_node);
if (!do_mask_node->HasPrimalAttr(kPrimalAttrUniqueId) && !do_mask_node->HasPrimalAttr(kPrimalAttrForwardUniqueId)) {
MS_LOG(WARNING) << "The DoMask cnode has no primal attr: " << do_mask_node->DebugString();
return nullptr;
}
MS_EXCEPTION_IF_NULL(do_mask_node);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
RectifyKernelInfoInPynativeProcess(do_mask_node);
return nullptr;
}
std::string unique_id;
if (do_mask_node->HasPrimalAttr(kPrimalAttrUniqueId)) {
unique_id = do_mask_node->GetPrimalAttr(kPrimalAttrUniqueId)->DumpText();
} else if (do_mask_node->HasPrimalAttr(kPrimalAttrForwardUniqueId)) {
unique_id = do_mask_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)->DumpText();
}
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask_node, 0);
auto iter = rectify_do_mask_map.find(unique_id);
if (iter == rectify_do_mask_map.end()) {
rectify_do_mask_map.insert({unique_id, do_mask_data_format});
} else {
auto &format = iter->second;
MS_LOG(INFO) << "Node: " << do_mask_node->DebugString() << " need convert origin format: " << do_mask_data_format
<< " to new format: " << format;
if (format != do_mask_data_format) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
AnfAlgo::GetSelectKernelBuildInfo(do_mask_node));
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputFormat(format, 0);
builder->SetOutputFormat(format, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask_node.get());
ReSelectChildNodeKernelInfo(do_mask_node, do_mask_node->func_graph());
}
}
return nullptr;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@
#include "backend/common/optimizer/optimizer.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
class RectifyDoMaskKernelInfo : public PatternProcessPass {
@ -31,16 +32,12 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
: PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
kernel_selecter(std::make_shared<KernelSelect>()) {}
~RectifyDoMaskKernelInfo() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
private:
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const FuncGraphPtr &graph) const;
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format,
const FuncGraphPtr &graph) const;
void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
void ReSelectChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
KernelSelectPtr kernel_selecter;
};
} // namespace opt

View File

@ -238,8 +238,12 @@ CNodePtr CreateDropoutDoMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr
dropout_do_mask->set_abstract(abstract);
dropout_do_mask->set_scope(dropout->scope());
common::AnfAlgo::CopyNodeAttr(kAttrKeepProb, dropout, dropout_do_mask);
if (dropout->HasPrimalAttr(kAttrMicro)) {
dropout_do_mask->AddPrimalAttr(kAttrMicro, dropout->GetPrimalAttr(kAttrMicro));
std::vector<std::string> need_primal_attr = {kAttrMicro, kPrimalAttrUniqueId, kPrimalAttrForwardUniqueId};
for (auto &primal_attr : need_primal_attr) {
if (dropout->HasPrimalAttr(primal_attr)) {
dropout_do_mask->AddPrimalAttr(primal_attr, dropout->GetPrimalAttr(primal_attr));
}
}
return dropout_do_mask;
}

View File

@ -0,0 +1,129 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "include/common/debug/draw.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/ms_context.h"
#include "plugin/device/ascend/optimizer/mindir/ascend_vm_op_adapter.h"
#define private public
#define protected public
#include "plugin/device/ascend/optimizer/format_type/rectify_do_mask_kernel_info.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
class TestHWRectifyDoMask : public BackendCommon {
public:
TestHWRectifyDoMask() : getPyFun_("gtest_input.pre_activate.rectify_do_mask_test", true) {}
~TestHWRectifyDoMask() override = default;
UT::PyFuncGraphFetcher getPyFun_;
void SetFormat(const FuncGraphPtr &g) {
auto node_list = TopoSort(g->get_return());
int cnt = 0;
for (auto &node : node_list) {
if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
auto cnode = node->cast<CNodePtr>();
cnt ++;
KernelBuildInfoBuilder builder;
if (cnt == 1) {
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetOutputsFormat({kOpFormat_DEFAULT});
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue("1"));
} else {
builder.SetInputsFormat({kOpFormat_FRAC_NZ, kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetOutputsFormat({kOpFormat_FRAC_NZ});
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
cnode->AddPrimalAttr(kPrimalAttrForwardUniqueId, MakeValue("1"));
}
}
}
}
void CheckNQ(const FuncGraphPtr &g) {
std::string format;
auto node_list = TopoSort(g->get_return());
for (auto &node : node_list) {
if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
if (format.empty()) {
format = AnfAlgo::GetInputFormat(node, 0);
} else {
EXPECT_NE(format, AnfAlgo::GetInputFormat(node, 0));
}
}
}
ASSERT_FALSE(format.empty());
}
void CheckEQ(const FuncGraphPtr &g) {
std::string format;
auto node_list = TopoSort(g->get_return());
for (auto &node : node_list) {
if (IsPrimitiveCNode(node, prim::kPrimDropOutDoMask)) {
if (format.empty()) {
format = AnfAlgo::GetInputFormat(node, 0);
} else {
EXPECT_EQ(format, AnfAlgo::GetInputFormat(node, 0));
}
}
}
ASSERT_FALSE(format.empty());
}
};
/// Feature: Test RectifyDoMaskKernelInfo pass
/// Description: Test RectifyDoMaskKernelInfo pass
/// Expectation: The forward and backward DropOutDoMask should select same format.
TEST_F(TestHWRectifyDoMask, test_rectify_dropout_do_mask) {
/*
* def test_rectify_dropout_do_mask(x):
* mask = gen_mask(shape, 0.9)
* res_x = do_mask(x, mask, 0.9)
* res_y = do_mask(y, mask, 0.9)
* return res_x, res_y
*
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr fg = getPyFun_.CallAndParseRet("test_rectify_dropout_do_mask", "f");
SetFormat(fg);
CheckNQ(fg);
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
auto pass_manager = std::make_shared<opt::PassManager>();
pass_manager->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
pass_manager->AddPass(std::make_shared<opt::RectifyDoMaskKernelInfo>());
graph_optimizer->AddPassManager(pass_manager);
auto new_g = graph_optimizer->Optimize(fg);
CheckEQ(new_g);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
shape = (2, 4, 2, 2)
gen_mask = P.DropoutGenMask(10, 10)
do_mask = P.DropoutDoMask()
class FnDict:
def __init__(self):
self.fn_dict = {}
def __call__(self, fn):
self.fn_dict[fn.__name__] = fn
def __getitem__(self, name):
return self.fn_dict.get(name)
def test_rectify_dropout_do_mask(tag):
"""
Feature: Test RectifyDoMaskKernelInfo pass
Description: Test RectifyDoMaskKernelInfo pass
Expectation: The forward and backward DropOutDoMask should select same format.
"""
fns = FnDict()
@fns
def f(x, y):
mask = gen_mask(shape, 0.9)
res_x = do_mask(x, mask, 0.9)
res_y = do_mask(y, mask, 0.9)
return res_x, res_y
return fns[tag]