!63637 Fixed dict grad issue in pynative and jit scenario.
Merge pull request !63637 from Margaret_wangrui/pynative_jit_dict
This commit is contained in:
commit
939526f29c
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -62,6 +62,7 @@
|
|||
#include "include/backend/debug/data_dump/dump_json_parser.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "include/backend/debug/profiler/profiling.h"
|
||||
#include "frontend/optimizer/fallback_rewriter.h"
|
||||
#if defined(__linux__) && defined(WITH_BACKEND)
|
||||
#include "include/backend/distributed/cluster/cluster_context.h"
|
||||
#include "include/backend/distributed/ps/ps_context.h"
|
||||
|
@ -967,6 +968,18 @@ bool GetJitBpropGraph(const ResourcePtr &resource) {
|
|||
return pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->GetJitGradGraph(resource);
|
||||
}
|
||||
|
||||
bool RewriterAfterOptAPassAfterJitBprop(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
FuncGraphPtr func_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->set_not_convert_jit(false);
|
||||
(void)mindspore::opt::RewriterAfterOptA(func_graph, resource);
|
||||
UpdateArgsSpec(func_graph, resource);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EliminateSpecialOpNode(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
if (resource->manager() == nullptr) {
|
||||
|
@ -1660,6 +1673,9 @@ std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
|
|||
// Eliminate forward cnode for grad graph
|
||||
(void)actions.emplace_back(std::make_pair(kGetJitBpropGraph, GetJitBpropGraph));
|
||||
|
||||
// Run the rewriter (convert dict to PyExecute) after the jit bprop graph is obtained.
|
||||
(void)actions.emplace_back(std::make_pair(kRewriterAfterJitBprop, RewriterAfterOptAPassAfterJitBprop));
|
||||
|
||||
// Eliminate the virtual mirror node
|
||||
(void)actions.emplace_back(std::make_pair(kEliminateSpecialOpNode, EliminateSpecialOpNode));
|
||||
(void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -94,7 +94,6 @@ using CompileGraphs = compile::CompileGraphs;
|
|||
using abstract::AnalysisResult;
|
||||
using mindspore::abstract::AnalysisContextPtr;
|
||||
using mindspore::validator::Validate;
|
||||
namespace {
|
||||
void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
|
@ -105,7 +104,6 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
|
|||
[](const AnfNodePtr &p) { return p->abstract(); });
|
||||
resource->set_args_abs(args_abs);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool PyInterpretToExecutePass(const ResourcePtr &resource) {
|
||||
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
|
||||
|
@ -149,6 +147,11 @@ bool TransformTopGraphPass(const ResourcePtr &resource) {
|
|||
}
|
||||
|
||||
bool RewriterAfterOptAPass(const ResourcePtr &resource) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (context->not_convert_jit()) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
FuncGraphPtr func_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -59,6 +59,7 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co
|
|||
const std::vector<bool> &need_grad_flags);
|
||||
FuncGraphPtr JitBpropGraphPass(const ResourcePtr &resource, bool need_renormalize);
|
||||
FuncGraphPtr FinalBpropGraphPass(const ResourcePtr &resource, bool has_control_flow);
|
||||
void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -71,6 +71,7 @@ const char kPipelineSplit[] = "pipeline_split";
|
|||
const char kOptimize[] = "optimize";
|
||||
const char kAutoMonadReorder[] = "auto_monad_reorder";
|
||||
const char kGetJitBpropGraph[] = "get_jit_bprop_graph";
|
||||
const char kRewriterAfterJitBprop[] = "rewriter_after_jit_bprop_graph";
|
||||
const char kEliminateSpecialOpNode[] = "eliminate_special_op_node";
|
||||
const char kValidate[] = "validate";
|
||||
const char kLoadMindir[] = "load_mindir";
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -1785,7 +1785,7 @@ AnfNodePtr AutoGradCellImpl::TraceShape(const FunctionNodePtr &fn, const ValuePt
|
|||
}
|
||||
return new_din;
|
||||
} else if (out_value->isa<ValueDictionary>()) {
|
||||
TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
|
||||
return TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get non tensor input " << out_value->ToString();
|
||||
return BuildSpecialNode(ad_param()->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -149,6 +149,7 @@ void RegMsContext(const py::module *m) {
|
|||
"Return whether this MindSpore package supports specified device.")
|
||||
.def("_enable_cell_recompute", &mindspore::MsContext::EnableCellRecompute, "Set a cell to be recomputed.")
|
||||
.def("load_plugin_error", &mindspore::MsContext::GetLoadPluginErrorStr,
|
||||
"Return error message when loading plugins for this MindSpore package.");
|
||||
"Return error message when loading plugins for this MindSpore package.")
|
||||
.def("_set_not_convert_jit", &mindspore::MsContext::set_not_convert_jit, "Set not convert jit.");
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -266,6 +266,9 @@ class MS_CORE_API MsContext {
|
|||
|
||||
std::string GetLoadPluginErrorStr() const { return load_plugin_error_(); }
|
||||
|
||||
void set_not_convert_jit(bool not_convert_jit) { not_convert_jit_ = not_convert_jit; }
|
||||
bool not_convert_jit() { return not_convert_jit_; }
|
||||
|
||||
private:
|
||||
void RefreshExecutionMode();
|
||||
void RefreshMemoryOffload();
|
||||
|
@ -297,6 +300,7 @@ class MS_CORE_API MsContext {
|
|||
static std::map<std::string, std::string> &PluginPathMap();
|
||||
enum CellReuseLevel cell_reuse_level_ = CellReuseLevel::kNoCellReuse;
|
||||
bool cell_recompute_{false};
|
||||
bool not_convert_jit_{false};
|
||||
};
|
||||
|
||||
// set method implementation for type bool/int/uint32_t/float/std::string
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -31,7 +31,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|||
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
|
||||
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
|
||||
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_, StarredGetItem_,\
|
||||
StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_
|
||||
StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
||||
from mindspore.common.api import _add_flags, _core
|
||||
|
@ -380,6 +380,7 @@ class GradOperation(GradOperation_):
|
|||
out = _grads_divided_by_device_num_if_recomputation(out)
|
||||
return out
|
||||
else:
|
||||
MSContext.get_instance()._set_not_convert_jit(True)
|
||||
grad_.pynative_ = True
|
||||
if not _pynative_executor.enable_grad():
|
||||
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
||||
|
@ -610,6 +611,7 @@ class _Grad(GradOperation_):
|
|||
return out, res[1:]
|
||||
return out
|
||||
else:
|
||||
MSContext.get_instance()._set_not_convert_jit(True)
|
||||
if not _pynative_executor.enable_grad():
|
||||
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
||||
grad_.pynative_ = True
|
||||
|
|
|
@ -1160,3 +1160,50 @@ def test_dict_inner_method_overrrided_3():
|
|||
return obj.to_tuple()
|
||||
ms_out = foo()
|
||||
assert ms_out == (1, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_jit_dict_grad():
    """
    Feature: Return dict in forward graph.
    Description: Take the gradient, via ops.grad in PyNative mode, of a jit function that returns a dict.
    Expectation: No exception, and the gradient with respect to the input equals 1.
    """

    @ms.jit
    def dict_net(a):
        x = {'a': a, 'b': 2}
        return x

    ms.set_context(mode=ms.PYNATIVE_MODE)
    try:
        out = ops.grad(dict_net)(ms.Tensor([1]))
        assert out == 1
    finally:
        # Switch back to graph mode even when the grad call or the assertion fails,
        # so the PyNative setting does not leak into tests that run afterwards.
        ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_jit_dict_grad_2():
    """
    Feature: Return dict in forward graph.
    Description: Take the gradient, via ops.GradOperation in PyNative mode, of a jit function that returns a dict.
    Expectation: No exception, and the gradient with respect to the input equals 1.
    """

    @ms.jit
    def dict_net(a):
        x = {'a': a, 'b': 2}
        return x

    ms.set_context(mode=ms.PYNATIVE_MODE)
    try:
        grad = ops.GradOperation()
        out = grad(dict_net)(ms.Tensor([1]))
        assert out == 1
    finally:
        # Switch back to graph mode even when the grad call or the assertion fails,
        # so the PyNative setting does not leak into tests that run afterwards.
        ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -122,3 +122,37 @@ def test_network_with_dict_output():
|
|||
ms_grad = GradOfFirstInput(ms_net, True)
|
||||
grad_out = ms_grad(Tensor(x), out)
|
||||
assert np.allclose(x, grad_out.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_jit_network_with_dict_output():
    """
    Feature: Dict sens for a jit-decorated network.
    Description: The jit-decorated construct of the network returns a dict; take the gradient of the first input
                 both without and with an explicit sens value.
    Expectation: Success.
    """
    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        @jit
        def construct(self, x):
            activated = self.relu(x)
            return {'a': activated}

    data = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    net = DicNet()

    # Without sens: the gradient is expected to be all ones.
    grad_without_sens = GradOfFirstInput(net, False)
    grad_result = grad_without_sens(Tensor(data))
    assert np.allclose(np.ones_like(data), grad_result.asnumpy())

    # With sens: the forward dict output is passed as sens, and the gradient is expected to equal the input.
    sens = net(Tensor(data))
    grad_with_sens = GradOfFirstInput(net, True)
    grad_result = grad_with_sens(Tensor(data), sens)
    assert np.allclose(data, grad_result.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue