diff --git a/mindspore/ccsrc/pipeline/jit/ps/action.cc b/mindspore/ccsrc/pipeline/jit/ps/action.cc
index 0b41a6f0bc4..0ca21124eac 100644
--- a/mindspore/ccsrc/pipeline/jit/ps/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/ps/action.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 Huawei Technologies Co., Ltd
+ * Copyright 2019-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -62,6 +62,7 @@
 #include "include/backend/debug/data_dump/dump_json_parser.h"
 #include "backend/common/graph_kernel/graph_kernel_flags.h"
 #include "include/backend/debug/profiler/profiling.h"
+#include "frontend/optimizer/fallback_rewriter.h"
 #if defined(__linux__) && defined(WITH_BACKEND)
 #include "include/backend/distributed/cluster/cluster_context.h"
 #include "include/backend/distributed/ps/ps_context.h"
@@ -957,6 +958,18 @@ bool GetJitBpropGraph(const ResourcePtr &resource) {
   return pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->GetJitGradGraph(resource);
 }
 
+bool RewriterAfterOptAPassAfterJitBprop(const ResourcePtr &resource) {
+  MS_EXCEPTION_IF_NULL(resource);
+  FuncGraphPtr func_graph = resource->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  context->set_not_convert_jit(false);
+  (void)mindspore::opt::RewriterAfterOptA(func_graph, resource);
+  UpdateArgsSpec(func_graph, resource);
+  return true;
+}
+
 bool EliminateSpecialOpNode(const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
   if (resource->manager() == nullptr) {
@@ -1648,6 +1661,9 @@ std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
   // Eliminate forward cnode for grad graph
   (void)actions.emplace_back(std::make_pair(kGetJitBpropGraph, GetJitBpropGraph));
 
+  // Rewriter(dict convert pyexecute) after jit bprop.
+  (void)actions.emplace_back(std::make_pair(kRewriterAfterJitBprop, RewriterAfterOptAPassAfterJitBprop));
+
   // Eliminate the virtual mirror node
   (void)actions.emplace_back(std::make_pair(kEliminateSpecialOpNode, EliminateSpecialOpNode));
   (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
diff --git a/mindspore/ccsrc/pipeline/jit/ps/pass.cc b/mindspore/ccsrc/pipeline/jit/ps/pass.cc
index 89cfd2c009e..7bf912d9ce2 100644
--- a/mindspore/ccsrc/pipeline/jit/ps/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/ps/pass.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 Huawei Technologies Co., Ltd
+ * Copyright 2019-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -94,7 +94,6 @@ using CompileGraphs = compile::CompileGraphs;
 using abstract::AnalysisResult;
 using mindspore::abstract::AnalysisContextPtr;
 using mindspore::validator::Validate;
-namespace {
 void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(resource);
@@ -105,7 +104,6 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
                  [](const AnfNodePtr &p) { return p->abstract(); });
   resource->set_args_abs(args_abs);
 }
-}  // namespace
 
 bool PyInterpretToExecutePass(const ResourcePtr &resource) {
   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
@@ -149,6 +147,11 @@ bool TransformTopGraphPass(const ResourcePtr &resource) {
 }
 
 bool RewriterAfterOptAPass(const ResourcePtr &resource) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  if (context->not_convert_jit()) {
+    return true;
+  }
   MS_EXCEPTION_IF_NULL(resource);
   FuncGraphPtr func_graph = resource->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
diff --git a/mindspore/ccsrc/pipeline/jit/ps/pass.h b/mindspore/ccsrc/pipeline/jit/ps/pass.h
index 4c1336f4bbf..b48defbb63a 100644
--- a/mindspore/ccsrc/pipeline/jit/ps/pass.h
+++ b/mindspore/ccsrc/pipeline/jit/ps/pass.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -58,6 +58,7 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co
                                 const std::vector<bool> &need_grad_flags);
 FuncGraphPtr JitBpropGraphPass(const ResourcePtr &resource, bool need_renormalize);
 FuncGraphPtr FinalBpropGraphPass(const ResourcePtr &resource, bool has_control_flow);
+void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource);
 }  // namespace pipeline
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/pipeline/jit/ps/resource.h b/mindspore/ccsrc/pipeline/jit/ps/resource.h
index ab93b651541..6b350059f4a 100644
--- a/mindspore/ccsrc/pipeline/jit/ps/resource.h
+++ b/mindspore/ccsrc/pipeline/jit/ps/resource.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 Huawei Technologies Co., Ltd
+ * Copyright 2019-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@ const char kPipelineSplit[] = "pipeline_split";
 const char kOptimize[] = "optimize";
 const char kAutoMonadReorder[] = "auto_monad_reorder";
 const char kGetJitBpropGraph[] = "get_jit_bprop_graph";
+const char kRewriterAfterJitBprop[] = "rewriter_after_jit_bprop_graph";
 const char kEliminateSpecialOpNode[] = "eliminate_special_op_node";
 const char kValidate[] = "validate";
 const char kLoadMindir[] = "load_mindir";
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/auto_grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/auto_grad.cc
index ccc41f67358..a8ffb9ecbe9 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/auto_grad.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/auto_grad.cc
@@ -1,7 +1,7 @@
 /**
  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  *
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1785,7 +1785,7 @@ AnfNodePtr AutoGradCellImpl::TraceShape(const FunctionNodePtr &fn, const ValuePt
     }
     return new_din;
   } else if (out_value->isa<ValueDictionary>()) {
-    TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
+    return TraceShapeForDict(fn, out_value, out_abs, input_tensor, din);
   }
   MS_LOG(DEBUG) << "Get non tensor input " << out_value->ToString();
   return BuildSpecialNode(ad_param()->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
index 33191e6983e..45bde6537df 100644
--- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 Huawei Technologies Co., Ltd
+ * Copyright 2020-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -149,6 +149,7 @@ void RegMsContext(const py::module *m) {
          "Return whether this MindSpore package supports specified device.")
     .def("_enable_cell_recompute", &mindspore::MsContext::EnableCellRecompute, "Set a cell to be recomputed.")
     .def("load_plugin_error", &mindspore::MsContext::GetLoadPluginErrorStr,
-         "Return error message when loading plugins for this MindSpore package.");
+         "Return error message when loading plugins for this MindSpore package.")
+    .def("_set_not_convert_jit", &mindspore::MsContext::set_not_convert_jit, "Set not convert jit.");
 }
 }  // namespace mindspore
diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h
index 66be3af1481..418fd967714 100644
--- a/mindspore/core/utils/ms_context.h
+++ b/mindspore/core/utils/ms_context.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 Huawei Technologies Co., Ltd
+ * Copyright 2019-2024 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -265,6 +265,9 @@ class MS_CORE_API MsContext {
 
   std::string GetLoadPluginErrorStr() const { return load_plugin_error_(); }
 
+  void set_not_convert_jit(bool not_convert_jit) { not_convert_jit_ = not_convert_jit; }
+  bool not_convert_jit() { return not_convert_jit_; }
+
  private:
   void RefreshExecutionMode();
   void RefreshMemoryOffload();
@@ -296,6 +299,7 @@
   static std::map<std::string, std::string> &PluginPathMap();
   enum CellReuseLevel cell_reuse_level_ = CellReuseLevel::kNoCellReuse;
   bool cell_recompute_{false};
+  bool not_convert_jit_{false};
 };
 
 // set method implementation for type bool/int/uint32_t/float/std::string
diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py
index a4e041b55ce..78204544add 100644
--- a/mindspore/python/mindspore/ops/composite/base.py
+++ b/mindspore/python/mindspore/ops/composite/base.py
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-2023 Huawei Technologies Co., Ltd
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
     ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
     ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
     HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_, StarredGetItem_,\
-    StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_
+    StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
 from mindspore.common import dtype as mstype
 from mindspore.common.api import jit, _pynative_executor, _wrap_func
 from mindspore.common.api import _add_flags, _core
@@ -380,6 +380,7 @@ class GradOperation(GradOperation_):
             out = _grads_divided_by_device_num_if_recomputation(out)
             return out
         else:
+            MSContext.get_instance()._set_not_convert_jit(True)
             grad_.pynative_ = True
             if not _pynative_executor.enable_grad():
                 raise RuntimeError("In no_grad context, you can not calculate gradient")
@@ -610,6 +611,7 @@ class _Grad(GradOperation_):
                 return out, res[1:]
             return out
         else:
+            MSContext.get_instance()._set_not_convert_jit(True)
             if not _pynative_executor.enable_grad():
                 raise RuntimeError("In no_grad context, you can not calculate gradient")
             grad_.pynative_ = True
diff --git a/tests/st/fallback/test_graph_fallback_runtime_dict.py b/tests/st/fallback/test_graph_fallback_runtime_dict.py
index 57c0166e4b5..828b9be544e 100644
--- a/tests/st/fallback/test_graph_fallback_runtime_dict.py
+++ b/tests/st/fallback/test_graph_fallback_runtime_dict.py
@@ -1160,3 +1160,50 @@ def test_dict_inner_method_overrrided_3():
         return obj.to_tuple()
     ms_out = foo()
     assert ms_out == (1, 2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_jit_dict_grad():
+    """
+    Feature: Return dict in forward graph.
+    Description: Support grad for dict return in pynative mode.
+    Expectation: No exception.
+    """
+
+    @ms.jit
+    def dict_net(a):
+        x = {'a': a, 'b': 2}
+        return x
+
+    ms.set_context(mode=ms.PYNATIVE_MODE)
+    out = ops.grad(dict_net)(ms.Tensor([1]))
+    assert out == 1
+    ms.set_context(mode=ms.GRAPH_MODE)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_jit_dict_grad_2():
+    """
+    Feature: Return dict in forward graph.
+    Description: Support grad for dict return in pynative mode.
+    Expectation: No exception.
+    """
+
+    @ms.jit
+    def dict_net(a):
+        x = {'a': a, 'b': 2}
+        return x
+
+    ms.set_context(mode=ms.PYNATIVE_MODE)
+    grad = ops.GradOperation()
+    out = grad(dict_net)(ms.Tensor([1]))
+    assert out == 1
+    ms.set_context(mode=ms.GRAPH_MODE)
diff --git a/tests/st/pynative/grad/test_common_grad.py b/tests/st/pynative/grad/test_common_grad.py
index bacfc04e891..1096abda260 100644
--- a/tests/st/pynative/grad/test_common_grad.py
+++ b/tests/st/pynative/grad/test_common_grad.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -122,3 +122,37 @@ def test_network_with_dict_output():
     ms_grad = GradOfFirstInput(ms_net, True)
     grad_out = ms_grad(Tensor(x), out)
     assert np.allclose(x, grad_out.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_jit_network_with_dict_output():
+    """
+    Feature: Test sens dict in jit
+    Description: Net out is dict in jit
+    Expectation: Success
+    """
+    class DicNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+
+        @jit
+        def construct(self, x):
+            y = self.relu(x)
+            out = {'a': y}
+            return out
+
+    x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
+    ms_net = DicNet()
+    # No sens
+    ms_grad = GradOfFirstInput(ms_net, False)
+    grad_out = ms_grad(Tensor(x))
+    assert np.allclose(np.ones_like(x), grad_out.asnumpy())
+
+    # Have sens
+    out = ms_net(Tensor(x))
+    ms_grad = GradOfFirstInput(ms_net, True)
+    grad_out = ms_grad(Tensor(x), out)
+    assert np.allclose(x, grad_out.asnumpy())