forked from mindspore-Ecosystem/mindspore
!45911 没有输入转属性后,SoftmaxGradFusion pass的pattern需要适配
Merge pull request !45911 from zuochuanyong/softmax_grad_fusion
This commit is contained in:
commit
1660edda2a
|
@ -19,7 +19,7 @@
|
|||
#include <set>
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,26 +37,29 @@ bool NeedFusion(const AnfNodePtr &reduce_sum, const AnfNodePtr &, const AnfNodeP
|
|||
return false;
|
||||
}
|
||||
|
||||
// check axis should be last dim
|
||||
auto prim = common::AnfAlgo::GetCNodePrimitive(reduce_sum_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (!prim->HasAttr(kAttrAxis)) {
|
||||
MS_LOG(INFO) << "ReduceSum should have attr axis if do fusion.";
|
||||
// Check that the axis of ReduceSum is the last dim
|
||||
auto axis_node = common::AnfAlgo::GetInputNode(reduce_sum_node, 1);
|
||||
MS_EXCEPTION_IF_NULL(axis_node);
|
||||
if (!axis_node->isa<ValueNode>()) {
|
||||
MS_LOG(INFO) << "ReduceSum's axis input should be an ValueNode if do fusion.";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t axis;
|
||||
auto axis_value = prim->GetAttr(kAttrAxis);
|
||||
if (axis_value->isa<Int64Imm>()) {
|
||||
axis = GetValue<int64_t>(axis_value);
|
||||
} else if (axis_value->isa<ValueTuple>()) {
|
||||
auto axis_tuple = GetValue<std::vector<int64_t>>(axis_value);
|
||||
if (axis_tuple.size() != 1) {
|
||||
MS_LOG(INFO) << "The size of ReduceSum's attr axis should be 1 if do fusion.";
|
||||
int64_t axis = 1;
|
||||
auto axis_node_ptr = axis_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_node_ptr);
|
||||
auto value_node = GetValueNode(axis_node_ptr);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
if (value_node->isa<tensor::Tensor>()) {
|
||||
auto const_tensor = value_node->cast<tensor::TensorPtr>();
|
||||
std::vector<int64_t> axis_vector = TensorValueToVector<int64_t>(const_tensor);
|
||||
if (axis_vector.size() != 1) {
|
||||
MS_LOG(INFO) << "The size of ReduceSum's axis input should be 1 if do fusion, but got " << axis_vector.size();
|
||||
return false;
|
||||
}
|
||||
axis = axis_tuple[0];
|
||||
axis = axis_vector[0];
|
||||
} else {
|
||||
MS_LOG(INFO) << "ReduceSum's axis input should be a Tensor if do fusion.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -85,7 +88,7 @@ bool NeedFusion(const AnfNodePtr &reduce_sum, const AnfNodePtr &, const AnfNodeP
|
|||
const BaseRef SoftmaxGradFusionCpu::DefinePattern() const {
|
||||
// pattern: mul(out, sub(dout, reduce_sum(mul(out, dout), -1)))
|
||||
VectorRef mul = VectorRef({prim::kPrimMul, input0_, input1_});
|
||||
VectorRef reduce_sum = VectorRef({reduce_sum_, mul});
|
||||
VectorRef reduce_sum = VectorRef({reduce_sum_, mul, axis_});
|
||||
VectorRef sub = VectorRef({prim::kPrimSub, input1_, reduce_sum});
|
||||
VectorRef pattern = VectorRef({prim::kPrimMul, input0_, sub});
|
||||
return pattern;
|
||||
|
|
|
@ -27,6 +27,7 @@ class SoftmaxGradFusionCpu : public PatternProcessPass {
|
|||
reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
|
||||
input0_ = std::make_shared<Var>();
|
||||
input1_ = std::make_shared<Var>();
|
||||
axis_ = std::make_shared<Var>();
|
||||
}
|
||||
~SoftmaxGradFusionCpu() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -36,6 +37,7 @@ class SoftmaxGradFusionCpu : public PatternProcessPass {
|
|||
VarPtr reduce_sum_;
|
||||
VarPtr input0_;
|
||||
VarPtr input1_;
|
||||
VarPtr axis_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -92,14 +92,20 @@ std::shared_ptr<session::KernelGraph> BackendCommon::GetKernelGraph(const FuncGr
|
|||
auto kernel_graph = session->ConstructKernelGraph(applies, outs);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(kernel_graph);
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto device = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device == kAscendDevice) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(kernel_graph);
|
||||
|
||||
MS_LOG(INFO) << "New Kernel Graph infos:";
|
||||
PrintGraphNodeList(kernel_graph);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "New Kernel Graph infos:";
|
||||
PrintGraphNodeList(kernel_graph);
|
||||
return kernel_graph;
|
||||
}
|
||||
|
||||
|
|
|
@ -23,9 +23,19 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class TestSoftmaxGradFusionCpu : public BackendCommon {
|
||||
public:
|
||||
TestSoftmaxGradFusionCpu() : get_py_fun_("gtest_input.pre_activate.softmax_grad_fusion_cpu", true) {}
|
||||
~TestSoftmaxGradFusionCpu() override = default;
|
||||
TestSoftmaxGradFusionCpu() : get_py_fun_("gtest_input.pre_activate.softmax_grad_fusion_cpu", true) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
orig_deivce_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kCPUDevice);
|
||||
}
|
||||
~TestSoftmaxGradFusionCpu() override {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, orig_deivce_);
|
||||
}
|
||||
|
||||
std::string orig_deivce_;
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
|
@ -50,8 +60,8 @@ TEST_F(TestSoftmaxGradFusionCpu, test_softmax_grad_fusion_cpu) {
|
|||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_fusion_cpu", "after");
|
||||
// TODO(zuochuanyong): In the CPU context, ReduceSum's axis input is not converted to an attribute
|
||||
// EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue