From 773959fbd39620f6b2feb949181600d146bf1493 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Mon, 28 Jun 2021 21:11:27 +0800 Subject: [PATCH] Add a experimental bprop directory to store new bprop --- mindspore/ccsrc/utils/primitive_utils.cc | 4 ++++ mindspore/ops/_grad_experimental/__init__.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 mindspore/ops/_grad_experimental/__init__.py diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index b2608136e09..a0e991cfa10 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -28,7 +28,11 @@ namespace mindspore { py::function GetBpropFunctionByObj(py::object obj) { static const std::string get_bprop_fn = "get_bprop_fn"; static const std::string ad_module = "mindspore.ops._grad"; + static const std::string ad_experimental_module = "mindspore.ops._grad_experimental"; py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj); + if (!fn || py::isinstance(fn)) { + fn = parse::python_adapter::GetPyFn(ad_experimental_module, get_bprop_fn)(obj); + } return fn; } diff --git a/mindspore/ops/_grad_experimental/__init__.py b/mindspore/ops/_grad_experimental/__init__.py new file mode 100644 index 00000000000..5fabf43c0bf --- /dev/null +++ b/mindspore/ops/_grad_experimental/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""grad experimental impl.""" +from .._grad.grad_base import get_bprop_fn + +__all__ = ['get_bprop_fn']