!47727 Clone the called graph when a primal graph calls another primal graph
Merge pull request !47727 from YuJianfeng/expand_j
commit 7ec6b71232
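In short, this change lets two primal graphs be differentiated together even when one of them calls the other: instead of rejecting the pattern (the guard shown in the first hunk below logged "Grad func_graph ... use another will be graded func_graph"), the called primal graph is cloned before J expansion. A minimal sketch of the user-level pattern this enables (hypothetical code, assuming the mindspore.grad / mindspore.jit APIs used by the new test at the bottom of this diff):

import numpy as np
from mindspore import Tensor, grad, jit

def inner(x, y):
    # primal graph that will be differentiated on its own
    return x * y

def outer(x, y):
    # another primal graph to be differentiated; it calls `inner`
    return inner(x, y) + x

@jit
def take_grads(x, y):
    # Both `inner` and `outer` get a J expansion, and `outer` uses `inner`,
    # which is exactly the case this commit handles by cloning the called graph.
    return grad(inner)(x, y) + grad(outer)(x, y)

x = Tensor(np.array([1.0, 2.0], np.float32))
y = Tensor(np.array([3.0, 4.0], np.float32))
print(take_grads(x, y))  # d(inner)/dx = y, d(outer)/dx = y + 1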
@@ -196,26 +196,10 @@ FuncGraphVector GradMultiFuncGraph(const FuncGraphVector &func_graphs, const opt
  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  // Graded func_graph should not call each other;
  for (const auto &func_graph : func_graphs) {
    const auto &used_total = func_graph->func_graphs_used_total();
    FuncGraphPtr used_fg;
    bool used = std::any_of(func_graphs.cbegin(), func_graphs.end(), [&used_total, &used_fg](const FuncGraphPtr &fg) {
      if (used_total.contains(fg)) {
        used_fg = fg;
        return true;
      }
      return false;
    });
    if (used) {
      MS_LOG(EXCEPTION) << "Grad func_graph: " << func_graph->ToString()
                        << " use another will be graded func_graph: " << used_fg->ToString();
    }
  }

  for (const auto &func_graph : func_graphs) {
    manager_ptr->AddFuncGraph(func_graph);
  }

  FuncGraphVector before_grad_fgs;
  if (optimizer->is_first_order_j()) {
    lift_fv_before_grad = true;
@@ -18,6 +18,7 @@
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"
#include "ir/func_graph_cloner.h"

namespace mindspore {
namespace opt {
@@ -50,6 +51,29 @@ AnfNodePtrList ExpandMultiJ(const FuncGraphVector &func_graphs, const OptimizerP
}
}  // namespace internal

void ExpandJPrim::CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs) {
  MS_EXCEPTION_IF_NULL(func_graphs);
  size_t func_graphs_size = func_graphs->size();
  for (size_t i = 0; i < func_graphs_size; ++i) {
    const auto &used_total = (*func_graphs)[i]->func_graphs_used_total();
    for (size_t j = 0; j < func_graphs_size; ++j) {
      auto fg_j = (*func_graphs)[j];
      if (j == i || !used_total.contains(fg_j)) {
        continue;
      }
      auto new_fg = BasicClone(fg_j);
      for (auto &j_node : prim_nodes_) {
        auto j_node_fg = GetValueNode<FuncGraphPtr>(j_node->input(1));
        if (j_node_fg == nullptr || j_node_fg != fg_j) {
          continue;
        }
        manager->Replace(j_node->input(1), NewValueNode(new_fg));
      }
      (*func_graphs)[j] = new_fg;
    }
  }
}

bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  // Check whether need to eliminate forward cnodes in pynative mode.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
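The new CloneUsedPrimalGraph pass walks every pair of to-be-graded graphs: if graph i transitively uses graph j, it clones graph j with BasicClone, repoints the J call sites that referenced the original at the clone, and records the clone in the list so the later ExpandMultiJ works on an independent copy. A toy illustration of that bookkeeping (plain Python over a made-up Graph record, not the MindSpore C++ API):

class Graph:
    """Toy stand-in for a func_graph: a name plus the graphs it (transitively) uses."""
    def __init__(self, name, used=()):
        self.name = name
        self.used = set(used)

def clone(fg):
    # stand-in for BasicClone: an independent copy of the called graph
    return Graph(fg.name + "_clone", fg.used)

def clone_used_primal_graphs(func_graphs, j_nodes):
    """Mirror of CloneUsedPrimalGraph; each j_node is a dict whose 'graph'
    entry plays the role of j_node->input(1)."""
    n = len(func_graphs)
    for i in range(n):
        used_total = func_graphs[i].used        # func_graphs_used_total()
        for j in range(n):
            fg_j = func_graphs[j]
            if j == i or fg_j not in used_total:
                continue
            new_fg = clone(fg_j)                # BasicClone(fg_j)
            for j_node in j_nodes:              # prim_nodes_
                if j_node["graph"] is fg_j:     # a J(fg_j) call site
                    j_node["graph"] = new_fg    # manager->Replace(...)
            func_graphs[j] = new_fg             # grade the clone, not the shared original

# Example: g uses f, and both are about to be differentiated.
f = Graph("f")
g = Graph("g", used=[f])
j_nodes = [{"graph": f}, {"graph": g}]
graphs = [f, g]
clone_used_primal_graphs(graphs, j_nodes)
assert graphs[0].name == "f_clone" and j_nodes[0]["graph"] is graphs[0]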
@@ -85,6 +109,8 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
      change = true;
    }
  }
  CloneUsedPrimalGraph(manager, &func_graphs);

  auto grad_func_graphs = internal::ExpandMultiJ(func_graphs, optimizer);
  for (const auto &j_node_index_iter : j_node_to_index_map) {
    const auto &j_node = j_node_index_iter.first;
@@ -36,6 +36,9 @@ class ExpandJPrim : public ExpandMetaFgPrim {
  ExpandJPrim() { prim_ = prim::kPrimJ; }
  virtual ~ExpandJPrim() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override;

 private:
  void CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs);
};
using ExpandJPrimPtr = std::shared_ptr<ExpandJPrim>;
}  // namespace irpass
@@ -1059,3 +1059,31 @@ def test_get_grad_wrap_with_msfunction_graph():
    expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
    real_grad = grad_wrap_with_msfunction_get_grad(x, y, z)
    assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_primal_graph_call_others():
    """
    Features: Auto grad.
    Description: Two graphs need to take derivatives and one calls the other.
    Expectation: Get the correct gradient.
    """
    def f(x, y):
        return x + y

    def g(x, y):
        return f(x, y) * y

    @jit
    def net(x, y):
        a = grad(f)(x, y)
        b = grad(g)(x, y)
        return a + b

    x = Tensor(np.array([1, 2]).astype(np.float32))
    y = Tensor(np.array([3, 4]).astype(np.float32))
    expected = Tensor(np.array([4, 5]).astype(np.float32))
    output = net(x, y)
    assert np.allclose(output.asnumpy(), expected.asnumpy())
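As a quick check on the expected value in the new test: the gradient of f(x, y) = x + y with respect to x is a vector of ones, and the gradient of g(x, y) = (x + y) * y with respect to x is y, so the sum is [1, 1] + [3, 4] = [4, 5]. The same arithmetic in plain NumPy, outside MindSpore:

import numpy as np

x = np.array([1, 2], np.float32)
y = np.array([3, 4], np.float32)
df_dx = np.ones_like(x)  # d(x + y)/dx
dg_dx = y                # d((x + y) * y)/dx
assert np.allclose(df_dx + dg_dx, np.array([4, 5], np.float32))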