!45706 报错信息修改(用户自定义网络,__call__用ms_function装饰,求反向报错)

Merge pull request !45706 from ligan/GradOperation
This commit is contained in:
i-robot 2022-11-24 07:07:58 +00:00 committed by Gitee
commit 88213e5a68
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 60 additions and 0 deletions

View File

@ -19,6 +19,7 @@
#include "frontend/operator/composite/composite.h"
#include <algorithm>
#include <tuple>
#include <regex>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
@ -33,6 +34,7 @@
#include "pipeline/jit/debug/trace.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "pipeline/jit/parse/resolve.h"
namespace mindspore {
// namespace to support composite operators definition
@ -863,6 +865,22 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
<< args_spec_list[0]->ToString();
}
if (fn->isa<abstract::PartialAbstractClosure>()) {
auto partial_abs = fn->cast<abstract::PartialAbstractClosurePtr>();
const auto &args = partial_abs->args();
if (!args.empty()) {
auto value = args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<parse::MsClassObject>()) {
auto value_obj = dyn_cast_ptr<parse::MsClassObject>(value);
auto obj_name = std::regex_replace(value_obj->name(), std::regex("MsClassObject:"), "");
MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type "
<< "object, but got object with jit_class type" << obj_name << ".\n'GradOperation' "
<< "does not support '__call__' magic methods as object.\nFor more details, "
<< "please refer to https://www.mindspore.cn/search?inputValue=Gradoperation";
}
}
}
// Waiting for implementation.
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);

View File

@ -22,6 +22,7 @@ from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.api import jit, jit_class
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)
@ -491,6 +492,47 @@ def test_grad_net_is_none():
assert "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got" in str(e)
def test_grad_call_self_net():
"""
Feature: Custom cell use GradOperation.
Description: GradOperation does not support __call__ magic methods as object.
Expectation: Raise an error.
"""
context.set_context(mode=context.GRAPH_MODE)
@jit_class
class Net:
def __init__(self):
self.weight = Parameter([10, 10], name='v')
@jit
def __call__(self, x):
a = self.func(x)
out = self.func(a)
return out
def func(self, x):
self.weight = 2 * self.weight
return self.weight * x
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super().__init__()
self.grad_op = ops.GradOperation()
self.net = net
def construct(self, x):
grad_net = self.grad_op(self.net)
grad = grad_net(x)
return grad
x = Tensor(np.array([1.0, 1.0], dtype=np.float32))
try:
GradNetWrtX(Net())(x)
except Exception as e:
assert "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type object, "\
"but got object with jit_class type 'Net'." in str(e)
def test_grad_missing_net():
context.set_context(mode=context.GRAPH_MODE)