forked from mindspore-Ecosystem/mindspore
!13155 [ME]Fix bug of embed J and side-by-side J
From: @chenfei52 Reviewed-by: @zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
b8b96e15e7
|
@ -123,8 +123,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
|
||||
IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
|
||||
// Gradient transforms
|
||||
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
|
||||
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
||||
|
||||
// branch culling
|
||||
|
|
|
@ -85,7 +85,6 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr accumulaten_eliminater_;
|
||||
|
||||
// Gradient irpasses
|
||||
SubstitutionPtr expand_jprim_;
|
||||
SubstitutionPtr minmaximum_grad_;
|
||||
|
||||
// inline
|
||||
|
|
|
@ -22,25 +22,28 @@ namespace irpass {
|
|||
namespace internal {
|
||||
AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
|
||||
ScopeGuard scope_guard(vnode->scope());
|
||||
|
||||
auto newg = ad::Kprim(vnode, resource);
|
||||
if (newg != nullptr) {
|
||||
return NewValueNode(newg);
|
||||
}
|
||||
|
||||
// when find in J failed, try in Jmeta
|
||||
auto prim = GetValueNode<PrimitivePtr>(vnode);
|
||||
MetaFuncGraphPtr meta = ad::Kmeta(prim, resource);
|
||||
if (meta != nullptr) {
|
||||
return NewValueNode(meta);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
|
||||
// if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph.
|
||||
// ExpandJ innermost graph first.
|
||||
bool CheckIfEmbedJ(const CNodePtr &j_node) {
|
||||
auto &value_node = j_node->input(1);
|
||||
if (IsValueNode<Primitive>(value_node)) {
|
||||
return false;
|
||||
}
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected j node:" << j_node->DebugString();
|
||||
}
|
||||
auto func_graph_manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(func_graph_manager);
|
||||
return func_graph_manager->func_graph_j_total(func_graph);
|
||||
|
@ -49,31 +52,48 @@ bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
|
|||
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
|
||||
if (IsValueNode<FuncGraph>(vnode)) {
|
||||
ScopeGuard scope_guard(vnode->scope());
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
|
||||
MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString();
|
||||
|
||||
// high_order_grad begin;
|
||||
// if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
|
||||
// ExpandJ innermost graph or primitive first.
|
||||
if (CheckIfEmbedJ(func_graph)) {
|
||||
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later";
|
||||
return nullptr;
|
||||
}
|
||||
// high_order_grad end;
|
||||
|
||||
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now";
|
||||
auto newfg = ad::Grad(func_graph, resource);
|
||||
return NewValueNode(newfg);
|
||||
}
|
||||
|
||||
if (IsValueNode<Primitive>(vnode)) {
|
||||
return ExpandJPrimitive(vnode, resource);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
|
||||
// Search all j nodes.
|
||||
GetJPrim(optimizer->resource()->manager());
|
||||
// Get j nodes that don't have embed j nodes.
|
||||
std::vector<CNodePtr> todo;
|
||||
// If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
|
||||
// ExpandJ innermost graph or primitive first.
|
||||
std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
|
||||
[](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
|
||||
// Expand j nodes that don't have embed j nodes.
|
||||
bool change = false;
|
||||
for (auto &j_node : todo) {
|
||||
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
|
||||
optimizer->resource()->manager()->Replace(j_node, expanded_j);
|
||||
change = true;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) {
|
||||
j_nodes_.clear();
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return());
|
||||
for (const auto &node : toposet) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||
j_nodes_.push_back(node->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,28 +31,17 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
namespace internal {
|
||||
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource);
|
||||
} // namespace internal
|
||||
|
||||
// {prim::kPrimJ, C}
|
||||
class ExpandJPrim : public AnfVisitor {
|
||||
class ExpandJPrim {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
x_ = nullptr;
|
||||
AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node);
|
||||
if (x_ != nullptr) {
|
||||
TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info()));
|
||||
auto j_node = internal::ExpandJ(x_, optimizer->resource());
|
||||
return j_node;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &node) override { x_ = node; }
|
||||
ExpandJPrim() = default;
|
||||
virtual ~ExpandJPrim() = default;
|
||||
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
|
||||
void GetJPrim(const FuncGraphManagerPtr &manager);
|
||||
|
||||
private:
|
||||
ValueNodePtr x_{nullptr};
|
||||
std::vector<CNodePtr> j_nodes_;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
@ -166,7 +167,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.mini_step_allgather_replace_,
|
||||
});
|
||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
|
||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||
|
||||
opt::OptPassConfig resolve_pass =
|
||||
|
@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
|
||||
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
|
||||
{"virtual_dataset", virtual_dataset},
|
||||
{"grad", grad},
|
||||
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
|
||||
{"resolve", resolve_pass},
|
||||
{"a_after_grad", a_after_grad},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
|
|
|
@ -88,14 +88,6 @@ class TestOptLib : public UT::Common {
|
|||
irpass::OptimizeIRPassLib irpass;
|
||||
};
|
||||
|
||||
TEST_F(TestOptLib, test_expendJ) {
|
||||
FuncGraphPtr before = getPyFun("test_expendJ");
|
||||
|
||||
ASSERT_TRUE(nullptr != before);
|
||||
|
||||
FuncGraphPtr after = RunSubs(before, std::vector<SubstitutionPtr>({irpass.expand_jprim_}));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_simplify_always_true_false) {
|
||||
FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1");
|
||||
FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2");
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "frontend/optimizer/cse_pass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "debug/draw.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -38,23 +39,24 @@ class TestOptOptimizer : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestOptOptimizer, test_step_opt) {
|
||||
FuncGraphPtr before = getPyFun("test_expendJ");
|
||||
FuncGraphPtr before = getPyFun("test_expandJ");
|
||||
|
||||
ASSERT_TRUE(nullptr != before);
|
||||
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||
std::shared_ptr<Optimizer> optimizer = Optimizer::MakeOptimizer("ut_test", res,
|
||||
{{"main",
|
||||
{
|
||||
// Branch culling
|
||||
irpass.switch_simplify_,
|
||||
std::shared_ptr<Optimizer> optimizer =
|
||||
Optimizer::MakeOptimizer("ut_test", res,
|
||||
{{"main",
|
||||
{
|
||||
// Branch culling
|
||||
irpass.switch_simplify_,
|
||||
|
||||
// Safe inlining
|
||||
irpass.arithmetic_simplify_,
|
||||
irpass.inline_,
|
||||
}},
|
||||
{"grad", {irpass.expand_jprim_}},
|
||||
{"cse", OptPassConfig(CSEPass(false))}},
|
||||
true);
|
||||
// Safe inlining
|
||||
irpass.arithmetic_simplify_,
|
||||
irpass.inline_,
|
||||
}},
|
||||
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
|
||||
{"cse", OptPassConfig(CSEPass(false))}},
|
||||
true);
|
||||
EXPECT_TRUE(optimizer.get() != nullptr);
|
||||
|
||||
auto after = optimizer->step(before);
|
||||
|
|
|
@ -133,8 +133,8 @@ def cost(x):
|
|||
J = Primitive('J')
|
||||
|
||||
|
||||
def test_expendJ(x):
|
||||
""" test_expendJ """
|
||||
def test_expandJ(x):
|
||||
""" test_expandJ """
|
||||
return J(cost)(x)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue