!13155 [ME]Fix bug of embed J and side-by-side J

From: @chenfei52
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-15 15:35:36 +08:00 committed by Gitee
commit b8b96e15e7
8 changed files with 65 additions and 65 deletions

View File

@ -123,8 +123,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling

View File

@ -85,7 +85,6 @@ class OptimizeIRPassLib {
SubstitutionPtr accumulaten_eliminater_;
// Gradient irpasses
SubstitutionPtr expand_jprim_;
SubstitutionPtr minmaximum_grad_;
// inline

View File

@ -22,25 +22,28 @@ namespace irpass {
namespace internal {
// Expand a J(Primitive) value node into its gradient representation.
// Lookup order: first the Kprim-registered bprop graph, then the Jmeta
// meta-funcgraph registry. Returns nullptr when neither lookup succeeds.
AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
  ScopeGuard scope_guard(vnode->scope());
  auto bprop_graph = ad::Kprim(vnode, resource);
  if (bprop_graph != nullptr) {
    return NewValueNode(bprop_graph);
  }
  // Lookup in J failed; fall back to the Jmeta registry.
  auto primitive = GetValueNode<PrimitivePtr>(vnode);
  MetaFuncGraphPtr meta_graph = ad::Kmeta(primitive, resource);
  if (meta_graph != nullptr) {
    return NewValueNode(meta_graph);
  }
  return nullptr;
}
bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
// if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph.
// ExpandJ innermost graph first.
bool CheckIfEmbedJ(const CNodePtr &j_node) {
auto &value_node = j_node->input(1);
if (IsValueNode<Primitive>(value_node)) {
return false;
}
auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Unexpected j node:" << j_node->DebugString();
}
auto func_graph_manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(func_graph_manager);
return func_graph_manager->func_graph_j_total(func_graph);
@ -49,31 +52,48 @@ bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
// Expand a J(x) value node into its gradient function.
// - J(FuncGraph): differentiate the graph with ad::Grad and return the result
//   as a new value node. If the graph itself still embeds another J (per
//   CheckIfEmbedJ), return nullptr so the innermost J is expanded first and
//   this node is retried on a later pass.
// - J(Primitive): delegate to ExpandJPrimitive (Kprim, then Kmeta lookup).
// Returns nullptr for any other node kind or when expansion is deferred.
// NOTE(review): this diff hunk appears to mix pre- and post-patch lines — the
// CheckIfEmbedJ overload introduced elsewhere in this commit takes a CNodePtr,
// not a FuncGraphPtr; confirm against the full post-patch file.
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
if (IsValueNode<FuncGraph>(vnode)) {
// Keep newly created nodes in the scope of the original value node.
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString();
// high_order_grad begin;
// if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
// ExpandJ innermost graph or primitive first.
if (CheckIfEmbedJ(func_graph)) {
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later";
return nullptr;
}
// high_order_grad end;
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now";
// Differentiate the whole graph and wrap the result in a value node.
auto newfg = ad::Grad(func_graph, resource);
return NewValueNode(newfg);
}
if (IsValueNode<Primitive>(vnode)) {
// J applied to a primitive: look up its registered gradient.
return ExpandJPrimitive(vnode, resource);
}
return nullptr;
}
} // namespace internal
// Pass entry point: expand every J node that does not itself embed another J.
// Returns true when at least one J node was actually replaced, so the
// optimizer knows the graph changed and reruns dependent passes.
bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  // Search all j nodes across the managed func graphs.
  auto manager = optimizer->resource()->manager();
  GetJPrim(manager);
  // Get j nodes that don't have embed j nodes.
  // If a graph also contains J(FuncGraph) or J(Primitive), it is skipped here
  // so the innermost graph or primitive is expanded first; the outer J is
  // picked up on a later invocation of this pass.
  std::vector<CNodePtr> todo;
  std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
               [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
  // Expand j nodes that don't have embed j nodes.
  bool change = false;
  for (auto &j_node : todo) {
    auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
    // ExpandJ can fail (e.g. a primitive with neither a Kprim nor a Kmeta
    // entry); guard so we never replace a node with nullptr, and only report
    // a change when a replacement really happened.
    if (expanded_j == nullptr) {
      continue;
    }
    manager->Replace(j_node, expanded_j);
    change = true;
  }
  return change;
}
// Rebuild the cached list of J CNodes: walk every func graph known to the
// manager in topological order and collect each CNode whose primitive is J.
void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) {
  j_nodes_.clear();
  for (auto &graph : manager->func_graphs()) {
    const auto topo_nodes = TopoSort(graph->get_return());
    for (const auto &anf_node : topo_nodes) {
      if (!IsPrimitiveCNode(anf_node, prim::kPrimJ)) {
        continue;
      }
      j_nodes_.push_back(anf_node->cast<CNodePtr>());
    }
  }
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -31,28 +31,17 @@
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource);
} // namespace internal
// {prim::kPrimJ, C}
class ExpandJPrim : public AnfVisitor {
class ExpandJPrim {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
x_ = nullptr;
AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node);
if (x_ != nullptr) {
TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info()));
auto j_node = internal::ExpandJ(x_, optimizer->resource());
return j_node;
}
return nullptr;
}
void Visit(const ValueNodePtr &node) override { x_ = node; }
ExpandJPrim() = default;
virtual ~ExpandJPrim() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
void GetJPrim(const FuncGraphManagerPtr &manager);
private:
ValueNodePtr x_{nullptr};
std::vector<CNodePtr> j_nodes_;
};
} // namespace irpass
} // namespace opt

View File

@ -41,6 +41,7 @@
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#include "ps/ps_context.h"
@ -166,7 +167,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.mini_step_allgather_replace_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
opt::irpass::ResolveIRPassLib resolve_irpass;
opt::OptPassConfig resolve_pass =
@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
{"virtual_dataset", virtual_dataset},
{"grad", grad},
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
{"resolve", resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},

View File

@ -88,14 +88,6 @@ class TestOptLib : public UT::Common {
irpass::OptimizeIRPassLib irpass;
};
// Smoke test for the expand_jprim substitution: parse the Python graph
// "test_expendJ" and run the pass over it. Only successful parsing is
// asserted; the substituted result `after` is not checked here.
TEST_F(TestOptLib, test_expendJ) {
FuncGraphPtr before = getPyFun("test_expendJ");
ASSERT_TRUE(nullptr != before);
FuncGraphPtr after = RunSubs(before, std::vector<SubstitutionPtr>({irpass.expand_jprim_}));
}
TEST_F(TestOptLib, test_simplify_always_true_false) {
FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1");
FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2");

View File

@ -24,6 +24,7 @@
#include "frontend/optimizer/cse_pass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "debug/draw.h"
namespace mindspore {
@ -38,23 +39,24 @@ class TestOptOptimizer : public UT::Common {
};
TEST_F(TestOptOptimizer, test_step_opt) {
FuncGraphPtr before = getPyFun("test_expendJ");
FuncGraphPtr before = getPyFun("test_expandJ");
ASSERT_TRUE(nullptr != before);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::shared_ptr<Optimizer> optimizer = Optimizer::MakeOptimizer("ut_test", res,
{{"main",
{
// Branch culling
irpass.switch_simplify_,
std::shared_ptr<Optimizer> optimizer =
Optimizer::MakeOptimizer("ut_test", res,
{{"main",
{
// Branch culling
irpass.switch_simplify_,
// Safe inlining
irpass.arithmetic_simplify_,
irpass.inline_,
}},
{"grad", {irpass.expand_jprim_}},
{"cse", OptPassConfig(CSEPass(false))}},
true);
// Safe inlining
irpass.arithmetic_simplify_,
irpass.inline_,
}},
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
{"cse", OptPassConfig(CSEPass(false))}},
true);
EXPECT_TRUE(optimizer.get() != nullptr);
auto after = optimizer->step(before);

View File

@ -133,8 +133,8 @@ def cost(x):
J = Primitive('J')
def test_expendJ(x):
""" test_expendJ """
def test_expandJ(x):
""" test_expandJ """
return J(cost)(x)