!47727 Clone the called graph when a primal graph calls another primal graph

Merge pull request !47727 from YuJianfeng/expand_j
i-robot 2023-01-12 10:08:51 +00:00 committed by Gitee
commit 7ec6b71232
4 changed files with 58 additions and 17 deletions
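The scenario this change enables, at the user level: two functions are both differentiated inside one compiled graph, and one of them calls the other. Previously the pass raised the exception removed in the first hunk below; now the called primal graph is cloned before grad expansion. A minimal sketch mirroring the test added in this PR (the top-level import paths are assumptions, not part of the diff):

import numpy as np
from mindspore import Tensor, grad, jit

def f(x, y):
    return x + y

def g(x, y):
    # g calls f, so f's primal graph is used inside g's primal graph.
    return f(x, y) * y

@jit
def net(x, y):
    # Both f and g are differentiated in one graph; with this change the called
    # primal graph (f) is cloned so both J expansions can proceed.
    return grad(f)(x, y) + grad(g)(x, y)

x = Tensor(np.array([1, 2]).astype(np.float32))
y = Tensor(np.array([3, 4]).astype(np.float32))
print(net(x, y))  # df/dx + dg/dx = 1 + y -> [4. 5.]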


@@ -196,26 +196,10 @@ FuncGraphVector GradMultiFuncGraph(const FuncGraphVector &func_graphs, const opt
  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  // Graded func_graph should not call each other;
  for (const auto &func_graph : func_graphs) {
    const auto &used_total = func_graph->func_graphs_used_total();
    FuncGraphPtr used_fg;
    bool used = std::any_of(func_graphs.cbegin(), func_graphs.end(), [&used_total, &used_fg](const FuncGraphPtr &fg) {
      if (used_total.contains(fg)) {
        used_fg = fg;
        return true;
      }
      return false;
    });
    if (used) {
      MS_LOG(EXCEPTION) << "Grad func_graph: " << func_graph->ToString()
                        << " use another will be graded func_graph: " << used_fg->ToString();
    }
  }
  for (const auto &func_graph : func_graphs) {
    manager_ptr->AddFuncGraph(func_graph);
  }
  FuncGraphVector before_grad_fgs;
  if (optimizer->is_first_order_j()) {
    lift_fv_before_grad = true;


@@ -18,6 +18,7 @@
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
@@ -50,6 +51,29 @@ AnfNodePtrList ExpandMultiJ(const FuncGraphVector &func_graphs, const OptimizerP
}
} // namespace internal
void ExpandJPrim::CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs) {
  MS_EXCEPTION_IF_NULL(func_graphs);
  size_t func_graphs_size = func_graphs->size();
  for (size_t i = 0; i < func_graphs_size; ++i) {
    const auto &used_total = (*func_graphs)[i]->func_graphs_used_total();
    for (size_t j = 0; j < func_graphs_size; ++j) {
      auto fg_j = (*func_graphs)[j];
      if (j == i || !used_total.contains(fg_j)) {
        continue;
      }
      // fg_j is called by another graph that will also be differentiated, so clone it and
      // redirect the J nodes that differentiate fg_j to the clone.
      auto new_fg = BasicClone(fg_j);
      for (auto &j_node : prim_nodes_) {
        auto j_node_fg = GetValueNode<FuncGraphPtr>(j_node->input(1));
        if (j_node_fg == nullptr || j_node_fg != fg_j) {
          continue;
        }
        manager->Replace(j_node->input(1), NewValueNode(new_fg));
      }
      (*func_graphs)[j] = new_fg;
    }
  }
}
bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  // Check whether need to eliminate forward cnodes in pynative mode.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
@@ -85,6 +109,8 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
      change = true;
    }
  }
  CloneUsedPrimalGraph(manager, &func_graphs);
  auto grad_func_graphs = internal::ExpandMultiJ(func_graphs, optimizer);
  for (const auto &j_node_index_iter : j_node_to_index_map) {
    const auto &j_node = j_node_index_iter.first;


@@ -36,6 +36,9 @@ class ExpandJPrim : public ExpandMetaFgPrim {
  ExpandJPrim() { prim_ = prim::kPrimJ; }
  virtual ~ExpandJPrim() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override;

 private:
  void CloneUsedPrimalGraph(const FuncGraphManagerPtr &manager, FuncGraphVector *func_graphs);
};
using ExpandJPrimPtr = std::shared_ptr<ExpandJPrim>;
} // namespace irpass


@@ -1059,3 +1059,31 @@ def test_get_grad_wrap_with_msfunction_graph():
    expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
    real_grad = grad_wrap_with_msfunction_get_grad(x, y, z)
    assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_primal_graph_call_others():
    """
    Features: Auto grad.
    Description: Two graphs need to take derivatives and one calls the other graph.
    Expectation: Get the correct gradient.
    """
    def f(x, y):
        return x + y

    def g(x, y):
        return f(x, y) * y

    @jit
    def net(x, y):
        a = grad(f)(x, y)
        b = grad(g)(x, y)
        return a + b

    x = Tensor(np.array([1, 2]).astype(np.float32))
    y = Tensor(np.array([3, 4]).astype(np.float32))
    expected = Tensor(np.array([4, 5]).astype(np.float32))
    output = net(x, y)
    assert np.allclose(output.asnumpy(), expected.asnumpy())