!63637 Fixed dict grad issue in pynative and jit scenario.

Merge pull request !63637 from Margaret_wangrui/pynative_jit_dict
This commit is contained in:
i-robot 2024-01-17 07:05:39 +00:00 committed by Gitee
commit 939526f29c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 123 additions and 14 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
* Copyright 2019-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -62,6 +62,7 @@
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "include/backend/debug/profiler/profiling.h"
#include "frontend/optimizer/fallback_rewriter.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#include "include/backend/distributed/ps/ps_context.h"
@ -967,6 +968,18 @@ bool GetJitBpropGraph(const ResourcePtr &resource) {
return pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->GetJitGradGraph(resource);
}
// Pipeline action that runs the RewriterAfterOptA rewrite on the graph held by `resource`.
// It first clears the not_convert_jit flag on the MsContext instance, then invokes the rewriter,
// and finally refreshes the argument abstracts stored in `resource` through UpdateArgsSpec.
// `resource`, its func_graph and the MsContext instance are each null-checked with
// MS_EXCEPTION_IF_NULL before use. Returns true whenever it runs to completion.
bool RewriterAfterOptAPassAfterJitBprop(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
// Clear the flag before rewriting. NOTE(review): RewriterAfterOptAPass returns early while this
// flag is set, and the flag appears to be set from the Python grad entry points - confirm.
context->set_not_convert_jit(false);
// The rewriter's return value is deliberately discarded.
(void)mindspore::opt::RewriterAfterOptA(func_graph, resource);
// Re-sync the argument abstracts kept in `resource` after the graph may have been rewritten.
UpdateArgsSpec(func_graph, resource);
return true;
}
bool EliminateSpecialOpNode(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
if (resource->manager() == nullptr) {
@ -1660,6 +1673,9 @@ std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
// Eliminate forward cnode for grad graph
(void)actions.emplace_back(std::make_pair(kGetJitBpropGraph, GetJitBpropGraph));
// Run the rewriter (dict conversion to PyExecute) after getting the jit bprop graph.
(void)actions.emplace_back(std::make_pair(kRewriterAfterJitBprop, RewriterAfterOptAPassAfterJitBprop));
// Eliminate the virtual mirror node
(void)actions.emplace_back(std::make_pair(kEliminateSpecialOpNode, EliminateSpecialOpNode));
(void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
* Copyright 2019-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -94,7 +94,6 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
using mindspore::validator::Validate;
namespace {
void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(resource);
@ -105,7 +104,6 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
[](const AnfNodePtr &p) { return p->abstract(); });
resource->set_args_abs(args_abs);
}
} // namespace
bool PyInterpretToExecutePass(const ResourcePtr &resource) {
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
@ -149,6 +147,11 @@ bool TransformTopGraphPass(const ResourcePtr &resource) {
}
bool RewriterAfterOptAPass(const ResourcePtr &resource) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->not_convert_jit()) {
return true;
}
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -59,6 +59,7 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co
const std::vector<bool> &need_grad_flags);
FuncGraphPtr JitBpropGraphPass(const ResourcePtr &resource, bool need_renormalize);
FuncGraphPtr FinalBpropGraphPass(const ResourcePtr &resource, bool has_control_flow);
// Rebuilds the list of argument abstracts from node abstracts and stores it in `resource`
// via set_args_abs. NOTE(review): the nodes are presumably the parameters of `func_graph` - confirm
// against the definition. Both arguments are null-checked by the implementation.
void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource);
} // namespace pipeline
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
* Copyright 2019-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -71,6 +71,7 @@ const char kPipelineSplit[] = "pipeline_split";
const char kOptimize[] = "optimize";
const char kAutoMonadReorder[] = "auto_monad_reorder";
const char kGetJitBpropGraph[] = "get_jit_bprop_graph";
const char kRewriterAfterJitBprop[] = "rewriter_after_jit_bprop_graph";
const char kEliminateSpecialOpNode[] = "eliminate_special_op_node";
const char kValidate[] = "validate";
const char kLoadMindir[] = "load_mindir";

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -1785,7 +1785,7 @@ AnfNodePtr AutoGradCellImpl::TraceShape(const FunctionNodePtr &fn, const ValuePt
}
return new_din;
} else if (out_value->isa<ValueDictionary>()) {
TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
return TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
}
MS_LOG(DEBUG) << "Get non tensor input " << out_value->ToString();
return BuildSpecialNode(ad_param()->tape_, out_value, out_abs, SpecialType::kZerosLikeType);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 Huawei Technologies Co., Ltd
* Copyright 2020-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -149,6 +149,7 @@ void RegMsContext(const py::module *m) {
"Return whether this MindSpore package supports specified device.")
.def("_enable_cell_recompute", &mindspore::MsContext::EnableCellRecompute, "Set a cell to be recomputed.")
.def("load_plugin_error", &mindspore::MsContext::GetLoadPluginErrorStr,
"Return error message when loading plugins for this MindSpore package.");
"Return error message when loading plugins for this MindSpore package.")
.def("_set_not_convert_jit", &mindspore::MsContext::set_not_convert_jit, "Set not convert jit.");
}
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
* Copyright 2019-2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -266,6 +266,9 @@ class MS_CORE_API MsContext {
std::string GetLoadPluginErrorStr() const { return load_plugin_error_(); }
// Stores the "not convert jit" flag on this context.
void set_not_convert_jit(bool not_convert_jit) {
  not_convert_jit_ = not_convert_jit;
}
bool not_convert_jit() { return not_convert_jit_; }
private:
void RefreshExecutionMode();
void RefreshMemoryOffload();
@ -297,6 +300,7 @@ class MS_CORE_API MsContext {
static std::map<std::string, std::string> &PluginPathMap();
enum CellReuseLevel cell_reuse_level_ = CellReuseLevel::kNoCellReuse;
bool cell_recompute_{false};
bool not_convert_jit_{false};
};
// set method implementation for type bool/int/uint32_t/float/std::string

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2023 Huawei Technologies Co., Ltd
# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,7 +31,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_, StarredGetItem_,\
StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_
StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
from mindspore.common import dtype as mstype
from mindspore.common.api import jit, _pynative_executor, _wrap_func
from mindspore.common.api import _add_flags, _core
@ -380,6 +380,7 @@ class GradOperation(GradOperation_):
out = _grads_divided_by_device_num_if_recomputation(out)
return out
else:
MSContext.get_instance()._set_not_convert_jit(True)
grad_.pynative_ = True
if not _pynative_executor.enable_grad():
raise RuntimeError("In no_grad context, you can not calculate gradient")
@ -610,6 +611,7 @@ class _Grad(GradOperation_):
return out, res[1:]
return out
else:
MSContext.get_instance()._set_not_convert_jit(True)
if not _pynative_executor.enable_grad():
raise RuntimeError("In no_grad context, you can not calculate gradient")
grad_.pynative_ = True

View File

@ -1160,3 +1160,50 @@ def test_dict_inner_method_overrrided_3():
return obj.to_tuple()
ms_out = foo()
assert ms_out == (1, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_jit_dict_grad():
    """
    Feature: Return dict in forward graph.
    Description: Support grad (ops.grad) for a jit function returning a dict in pynative mode.
    Expectation: No exception.
    """
    @ms.jit
    def dict_net(a):
        x = {'a': a, 'b': 2}
        return x

    ms.set_context(mode=ms.PYNATIVE_MODE)
    try:
        out = ops.grad(dict_net)(ms.Tensor([1]))
        assert out == 1
    finally:
        # Restore graph mode even when the grad call raises or the assertion fails,
        # so that PYNATIVE mode does not leak into the tests that run afterwards.
        ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_jit_dict_grad_2():
    """
    Feature: Return dict in forward graph.
    Description: Support grad (ops.GradOperation) for a jit function returning a dict in pynative mode.
    Expectation: No exception.
    """
    @ms.jit
    def dict_net(a):
        x = {'a': a, 'b': 2}
        return x

    ms.set_context(mode=ms.PYNATIVE_MODE)
    try:
        grad = ops.GradOperation()
        out = grad(dict_net)(ms.Tensor([1]))
        assert out == 1
    finally:
        # Restore graph mode even when the grad call raises or the assertion fails,
        # so that PYNATIVE mode does not leak into the tests that run afterwards.
        ms.set_context(mode=ms.GRAPH_MODE)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -122,3 +122,37 @@ def test_network_with_dict_output():
ms_grad = GradOfFirstInput(ms_net, True)
grad_out = ms_grad(Tensor(x), out)
assert np.allclose(x, grad_out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_jit_network_with_dict_output():
    """
    Feature: Test sens dict in jit
    Description: The jit-decorated construct of the net returns a dict
    Expectation: Success
    """
    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        @jit
        def construct(self, x):
            activated = self.relu(x)
            result = {'a': activated}
            return result

    input_np = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    net = DicNet()

    # Gradient of the first input without passing sens: the test expects all ones.
    grad_without_sens = GradOfFirstInput(net, False)
    grad_value = grad_without_sens(Tensor(input_np))
    assert np.allclose(np.ones_like(input_np), grad_value.asnumpy())

    # Gradient of the first input with the forward dict output passed as sens.
    forward_out = net(Tensor(input_np))
    grad_with_sens = GradOfFirstInput(net, True)
    grad_value = grad_with_sens(Tensor(input_np), forward_out)
    assert np.allclose(input_np, grad_value.asnumpy())