!47727 Clone the called graph when a primal graph calls another primal graph
Merge pull request !47727 from YuJianfeng/expand_j
commit 7ec6b71232
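In short: GradMultiFuncGraph used to reject the case where one graph being graded calls another graph that is also being graded, and ExpandJPrim now clones the called primal graph instead. A minimal user-level sketch of that scenario, mirroring the new test added at the end of this diff (the import paths are assumptions based on the names Tensor/jit/grad used in the test):

import numpy as np
from mindspore import Tensor, jit, grad  # assumed imports; the test body uses these names directly

def f(x, y):
    return x + y

def g(x, y):
    return f(x, y) * y  # g's primal graph calls f's primal graph

@jit
def net(x, y):
    # f and g are graded in the same pass; before this change the check removed in the
    # first hunk raised "Grad func_graph ... use another will be graded func_graph".
    return grad(f)(x, y) + grad(g)(x, y)

x = Tensor(np.array([1, 2]).astype(np.float32))
y = Tensor(np.array([3, 4]).astype(np.float32))
print(net(x, y))  # df/dx + dg/dx = 1 + y -> [4. 5.]
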
@@ -196,26 +196,10 @@ FuncGraphVector GradMultiFuncGraph(const FuncGraphVector &func_graphs, const opt
   const auto &resources = optimizer->resource();
   auto manager_ptr = resources->manager();
   MS_EXCEPTION_IF_NULL(manager_ptr);
-  // Graded func_graph should not call each other;
-  for (const auto &func_graph : func_graphs) {
-    const auto &used_total = func_graph->func_graphs_used_total();
-    FuncGraphPtr used_fg;
-    bool used = std::any_of(func_graphs.cbegin(), func_graphs.end(), [&used_total, &used_fg](const FuncGraphPtr &fg) {
-      if (used_total.contains(fg)) {
-        used_fg = fg;
-        return true;
-      }
-      return false;
-    });
-    if (used) {
-      MS_LOG(EXCEPTION) << "Grad func_graph: " << func_graph->ToString()
-                        << " use another will be graded func_graph: " << used_fg->ToString();
-    }
-  }
   for (const auto &func_graph : func_graphs) {
     manager_ptr->AddFuncGraph(func_graph);
   }
 
   FuncGraphVector before_grad_fgs;
   if (optimizer->is_first_order_j()) {
     lift_fv_before_grad = true;

@@ -18,6 +18,7 @@
 #include "frontend/optimizer/irpass/gradient_eliminate.h"
 #include "pipeline/pynative/pynative_execute.h"
+#include "ir/func_graph_cloner.h"
 
 namespace mindspore {
 namespace opt {

@@ -50,6 +51,29 @@ AnfNodePtrList ExpandMultiJ(const FuncGraphVector &func_graphs, const OptimizerP
 }
 } // namespace internal
 
+void ExpandJPrim::CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs) {
+  MS_EXCEPTION_IF_NULL(func_graphs);
+  size_t func_graphs_size = func_graphs->size();
+  for (size_t i = 0; i < func_graphs_size; ++i) {
+    const auto &used_total = (*func_graphs)[i]->func_graphs_used_total();
+    for (size_t j = 0; j < func_graphs_size; ++j) {
+      auto fg_j = (*func_graphs)[j];
+      if (j == i || !used_total.contains(fg_j)) {
+        continue;
+      }
+      auto new_fg = BasicClone(fg_j);
+      for (auto &j_node : prim_nodes_) {
+        auto j_node_fg = GetValueNode<FuncGraphPtr>(j_node->input(1));
+        if (j_node_fg == nullptr || j_node_fg != fg_j) {
+          continue;
+        }
+        manager->Replace(j_node->input(1), NewValueNode(new_fg));
+      }
+      (*func_graphs)[j] = new_fg;
+    }
+  }
+}
+
 bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
   // Check whether need to eliminate forward cnodes in pynative mode.
   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {

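For readers skimming the C++ above, a rough plain-Python model of what CloneUsedPrimalGraph does; the graph and J-node structures here are hypothetical stand-ins, not MindSpore API:

import copy

def clone_used_primal_graphs(graphs_to_grad, uses, j_node_primal):
    """graphs_to_grad: the graphs about to be graded together.
    uses: graph -> set of graphs it calls, directly or transitively
    (stand-in for func_graphs_used_total()).
    j_node_primal: J node -> the primal graph it wraps."""
    n = len(graphs_to_grad)
    for i in range(n):
        for j in range(n):
            fg_j = graphs_to_grad[j]
            # Clone fg_j only when another to-be-graded graph still calls it.
            if j == i or fg_j not in uses.get(graphs_to_grad[i], set()):
                continue
            new_fg = copy.deepcopy(fg_j)  # stand-in for BasicClone
            # Repoint every J(fg_j) node to the clone, mirroring manager->Replace(...).
            for j_node, primal in j_node_primal.items():
                if primal is fg_j:
                    j_node_primal[j_node] = new_fg
            graphs_to_grad[j] = new_fg

This cloning replaces the hard error that the first hunk removes from GradMultiFuncGraph, so graphs that call each other can now be graded in the same pass.
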
@@ -85,6 +109,8 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
       change = true;
     }
   }
+  CloneUsedPrimalGraph(manager, &func_graphs);
+
   auto grad_func_graphs = internal::ExpandMultiJ(func_graphs, optimizer);
   for (const auto &j_node_index_iter : j_node_to_index_map) {
     const auto &j_node = j_node_index_iter.first;

@@ -36,6 +36,9 @@ class ExpandJPrim : public ExpandMetaFgPrim {
   ExpandJPrim() { prim_ = prim::kPrimJ; }
   virtual ~ExpandJPrim() = default;
   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override;
+
+ private:
+  void CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs);
 };
 using ExpandJPrimPtr = std::shared_ptr<ExpandJPrim>;
 } // namespace irpass

@@ -1059,3 +1059,31 @@ def test_get_grad_wrap_with_msfunction_graph():
     expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
     real_grad = grad_wrap_with_msfunction_get_grad(x, y, z)
     assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_grad_primal_graph_call_others():
+    """
+    Features: Auto grad.
+    Description: Two graphs need to take a derivative and one calls the other graph.
+    Expectation: Get the correct gradient.
+    """
+    def f(x, y):
+        return x + y
+
+    def g(x, y):
+        return f(x, y) * y
+
+    @jit
+    def net(x, y):
+        a = grad(f)(x, y)
+        b = grad(g)(x, y)
+        return a + b
+
+    x = Tensor(np.array([1, 2]).astype(np.float32))
+    y = Tensor(np.array([3, 4]).astype(np.float32))
+    expected = Tensor(np.array([4, 5]).astype(np.float32))
+    output = net(x, y)
+    assert np.allclose(output.asnumpy(), expected.asnumpy())