!45911 After the removal of input-to-attr conversion, the pattern of the SoftmaxGradFusion pass needs to be adapted

Merge pull request !45911 from zuochuanyong/softmax_grad_fusion
i-robot 2022-11-28 07:39:24 +00:00 committed by Gitee
commit 1660edda2a
4 changed files with 48 additions and 27 deletions
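
In short, the adaptation this merge makes (sketched below in hypothetical IR notation; the %-names are illustrative and not taken from the diff): ReduceSum used to carry its reduction axis as a primitive attribute, but without the input-to-attr conversion it now receives the axis as a second input, so the fusion pattern must capture one more operand.

// Before: axis carried as an attribute of the ReduceSum primitive.
//   %1 = Mul(%out, %dout)
//   %2 = ReduceSum(%1) {axis: -1}
// After: axis passed as a constant tensor input, matched by the new axis_ Var.
//   %a = ValueNode(Tensor([-1]))
//   %1 = Mul(%out, %dout)
//   %2 = ReduceSum(%1, %a)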


@@ -19,7 +19,7 @@
 #include <set>
 #include "backend/common/session/anf_runtime_algorithm.h"
 #include "include/common/utils/anfalgo.h"
-#include "ir/primitive.h"
+#include "include/common/utils/convert_utils.h"
 #include "include/common/utils/utils.h"

 namespace mindspore {
namespace mindspore {
@@ -37,26 +37,29 @@ bool NeedFusion(const AnfNodePtr &reduce_sum, const AnfNodePtr &, const AnfNodeP
     return false;
   }
-  // check axis should be last dim
-  auto prim = common::AnfAlgo::GetCNodePrimitive(reduce_sum_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  if (!prim->HasAttr(kAttrAxis)) {
-    MS_LOG(INFO) << "ReduceSum should have attr axis if do fusion.";
+  // check axis of ReduceSum should be the last dim
+  auto axis_node = common::AnfAlgo::GetInputNode(reduce_sum_node, 1);
+  MS_EXCEPTION_IF_NULL(axis_node);
+  if (!axis_node->isa<ValueNode>()) {
+    MS_LOG(INFO) << "ReduceSum's axis input should be a ValueNode if do fusion.";
     return false;
   }
-  int64_t axis;
-  auto axis_value = prim->GetAttr(kAttrAxis);
-  if (axis_value->isa<Int64Imm>()) {
-    axis = GetValue<int64_t>(axis_value);
-  } else if (axis_value->isa<ValueTuple>()) {
-    auto axis_tuple = GetValue<std::vector<int64_t>>(axis_value);
-    if (axis_tuple.size() != 1) {
-      MS_LOG(INFO) << "The size of ReduceSum's attr axis should be 1 if do fusion.";
+  int64_t axis = 1;
+  auto axis_node_ptr = axis_node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(axis_node_ptr);
+  auto value_node = GetValueNode(axis_node_ptr);
+  MS_EXCEPTION_IF_NULL(value_node);
+  if (value_node->isa<tensor::Tensor>()) {
+    auto const_tensor = value_node->cast<tensor::TensorPtr>();
+    std::vector<int64_t> axis_vector = TensorValueToVector<int64_t>(const_tensor);
+    if (axis_vector.size() != 1) {
+      MS_LOG(INFO) << "The size of ReduceSum's axis input should be 1 if do fusion, but got " << axis_vector.size();
       return false;
     }
-    axis = axis_tuple[0];
+    axis = axis_vector[0];
   } else {
+    MS_LOG(INFO) << "ReduceSum's axis input should be a Tensor if do fusion.";
     return false;
   }
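
To make the new check concrete (a hedged walk-through; the tensor value is invented for illustration): for a call ReduceSum(mul_out, axis) whose axis input is a ValueNode holding an int64 Tensor([-1]), the code above resolves as follows.

// axis_node   -> the ValueNode wrapping Tensor([-1])
// value_node  -> tensor::Tensor([-1])
// axis_vector -> {-1}, via TensorValueToVector<int64_t>
// axis        -> -1, later checked against the last dimension
// A multi-element axis tensor, or a non-Tensor value, rejects the fusion instead.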
@@ -85,7 +88,7 @@ bool NeedFusion(const AnfNodePtr &reduce_sum, const AnfNodePtr &, const AnfNodeP
 const BaseRef SoftmaxGradFusionCpu::DefinePattern() const {
   // pattern: mul(out, sub(dout, reduce_sum(mul(out, dout), -1)))
   VectorRef mul = VectorRef({prim::kPrimMul, input0_, input1_});
-  VectorRef reduce_sum = VectorRef({reduce_sum_, mul});
+  VectorRef reduce_sum = VectorRef({reduce_sum_, mul, axis_});
   VectorRef sub = VectorRef({prim::kPrimSub, input1_, reduce_sum});
   VectorRef pattern = VectorRef({prim::kPrimMul, input0_, sub});
   return pattern;


@@ -27,6 +27,7 @@ class SoftmaxGradFusionCpu : public PatternProcessPass {
     reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
     input0_ = std::make_shared<Var>();
     input1_ = std::make_shared<Var>();
+    axis_ = std::make_shared<Var>();
   }
   ~SoftmaxGradFusionCpu() override = default;
   const BaseRef DefinePattern() const override;
@@ -36,6 +37,7 @@ class SoftmaxGradFusionCpu : public PatternProcessPass {
   VarPtr reduce_sum_;
   VarPtr input0_;
   VarPtr input1_;
+  VarPtr axis_;
 };
 }  // namespace opt
 }  // namespace mindspore


@@ -92,14 +92,20 @@ std::shared_ptr<session::KernelGraph> BackendCommon::GetKernelGraph(const FuncGr
   auto kernel_graph = session->ConstructKernelGraph(applies, outs);
   kernel_graph->SetExecOrderByDefault();
-  auto optimizer = std::make_shared<opt::GraphOptimizer>();
-  auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
-  optimizer->AddPassManager(pm);
-  optimizer->Optimize(kernel_graph);
-  MS_LOG(INFO) << "New Kernel Graph infos:";
-  PrintGraphNodeList(kernel_graph);
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  auto device = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  if (device == kAscendDevice) {
+    auto optimizer = std::make_shared<opt::GraphOptimizer>();
+    auto pm = std::make_shared<opt::PassManager>();
+    pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
+    optimizer->AddPassManager(pm);
+    optimizer->Optimize(kernel_graph);
+    MS_LOG(INFO) << "New Kernel Graph infos:";
+    PrintGraphNodeList(kernel_graph);
+  }
   return kernel_graph;
 }
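
This device guard is what ties the test utility to the fusion change: AscendVmOpAdapter apparently performs the very input-to-attr conversion the PR title refers to (the TODO removed in the test below says as much for ReduceSum), so it must not run on graphs built for CPU. Schematically (an illustration, not code from the patch):

// device == kAscendDevice -> run AscendVmOpAdapter; ReduceSum's axis input
//                            may be folded back into an attribute
// any other target        -> leave the graph untouched; ReduceSum keeps its
//                            axis input, as SoftmaxGradFusionCpu's pattern expects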


@@ -23,9 +23,19 @@ namespace mindspore {
 namespace opt {
 class TestSoftmaxGradFusionCpu : public BackendCommon {
  public:
-  TestSoftmaxGradFusionCpu() : get_py_fun_("gtest_input.pre_activate.softmax_grad_fusion_cpu", true) {}
-  ~TestSoftmaxGradFusionCpu() override = default;
+  TestSoftmaxGradFusionCpu() : get_py_fun_("gtest_input.pre_activate.softmax_grad_fusion_cpu", true) {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    orig_device_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+    context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kCPUDevice);
+  }
+  ~TestSoftmaxGradFusionCpu() override {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    context->set_param<std::string>(MS_CTX_DEVICE_TARGET, orig_device_);
+  }
+  std::string orig_device_;
   UT::PyFuncGraphFetcher get_py_fun_;
 };
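
The fixture saves the global device target, pins it to kCPUDevice for the test, and restores it in the destructor, so GetKernelGraph above skips the Ascend-only pass. The same save/restore could be expressed as a small RAII guard (a hypothetical sketch using only the MsContext calls seen in this diff; not part of the patch):

// Hypothetical RAII helper; not part of this patch.
class ScopedDeviceTarget {
 public:
  explicit ScopedDeviceTarget(const std::string &target) {
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    saved_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    context->set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
  }
  ~ScopedDeviceTarget() {
    auto context = MsContext::GetInstance();
    if (context != nullptr) {
      // Restore whatever target was active before the guard was created.
      context->set_param<std::string>(MS_CTX_DEVICE_TARGET, saved_);
    }
  }

 private:
  std::string saved_;
};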
@@ -50,8 +60,8 @@ TEST_F(TestSoftmaxGradFusionCpu, test_softmax_grad_fusion_cpu) {
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_fusion_cpu", "after");
-  // TODO(zuochuanyong) In cpu context, reduce sum not input to attr
-  // EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 }  // namespace opt
 }  // namespace mindspore