!45706 报错信息修改(用户自定义网络,__call__用ms_function装饰,求反向报错)
Merge pull request !45706 from ligan/GradOperation
This commit is contained in:
commit
88213e5a68
|
@ -19,6 +19,7 @@
|
|||
#include "frontend/operator/composite/composite.h"
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#include <regex>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
@ -33,6 +34,7 @@
|
|||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
|
@ -863,6 +865,22 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
|
||||
<< args_spec_list[0]->ToString();
|
||||
}
|
||||
if (fn->isa<abstract::PartialAbstractClosure>()) {
|
||||
auto partial_abs = fn->cast<abstract::PartialAbstractClosurePtr>();
|
||||
const auto &args = partial_abs->args();
|
||||
if (!args.empty()) {
|
||||
auto value = args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<parse::MsClassObject>()) {
|
||||
auto value_obj = dyn_cast_ptr<parse::MsClassObject>(value);
|
||||
auto obj_name = std::regex_replace(value_obj->name(), std::regex("MsClassObject:"), "");
|
||||
MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type "
|
||||
<< "object, but got object with jit_class type" << obj_name << ".\n'GradOperation' "
|
||||
<< "does not support '__call__' magic methods as object.\nFor more details, "
|
||||
<< "please refer to https://www.mindspore.cn/search?inputValue=Gradoperation";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Waiting for implementation.
|
||||
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore import Tensor
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.api import jit, jit_class
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
grad_by_list = C.GradOperation(get_by_list=True)
|
||||
|
@ -491,6 +492,47 @@ def test_grad_net_is_none():
|
|||
assert "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got" in str(e)
|
||||
|
||||
|
||||
def test_grad_call_self_net():
|
||||
"""
|
||||
Feature: Custom cell use GradOperation.
|
||||
Description: GradOperation does not support __call__ magic methods as object.
|
||||
Expectation: Raise an error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
@jit_class
|
||||
class Net:
|
||||
def __init__(self):
|
||||
self.weight = Parameter([10, 10], name='v')
|
||||
|
||||
@jit
|
||||
def __call__(self, x):
|
||||
a = self.func(x)
|
||||
out = self.func(a)
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
self.weight = 2 * self.weight
|
||||
return self.weight * x
|
||||
|
||||
class GradNetWrtX(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.grad_op = ops.GradOperation()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
grad_net = self.grad_op(self.net)
|
||||
grad = grad_net(x)
|
||||
return grad
|
||||
|
||||
x = Tensor(np.array([1.0, 1.0], dtype=np.float32))
|
||||
try:
|
||||
GradNetWrtX(Net())(x)
|
||||
except Exception as e:
|
||||
assert "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type object, "\
|
||||
"but got object with jit_class type 'Net'." in str(e)
|
||||
|
||||
|
||||
def test_grad_missing_net():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
|
Loading…
Reference in New Issue