forked from mindspore-Ecosystem/mindspore
[dfx]optimize exception info print
This commit is contained in:
parent
209a6e1df6
commit
5acf3c8ade
|
@ -327,6 +327,31 @@ class Cell(Cell_):
|
|||
_pynative_exec.leave_construct(self)
|
||||
return output
|
||||
|
||||
def check_construct_args(self, *inputs, **kwargs):
|
||||
"""Check the args needed by the function construct"""
|
||||
if kwargs:
|
||||
raise ValueError("For 'graph' mode, the outermost network does not support passing "
|
||||
"variable key-value pair parameters.")
|
||||
positional_args = 0
|
||||
default_args = 0
|
||||
for value in inspect.signature(self.construct).parameters.values():
|
||||
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
return
|
||||
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
if value.default is inspect.Parameter.empty:
|
||||
positional_args += 1
|
||||
else:
|
||||
default_args += 1
|
||||
|
||||
if len(inputs) < positional_args:
|
||||
raise ValueError(
|
||||
f"The function construct need {positional_args} positional argument, but only provided {len(inputs)}.")
|
||||
|
||||
if len(inputs) > positional_args + default_args:
|
||||
raise ValueError(
|
||||
f"The function construct need {positional_args} positional argument and {default_args} default "
|
||||
f"argument, but provided {len(inputs)}")
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
if self.__class__.construct is Cell.construct:
|
||||
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
|
||||
|
@ -336,9 +361,7 @@ class Cell(Cell_):
|
|||
inputs = bound_args.args
|
||||
kwargs = bound_args.kwargs
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
if kwargs:
|
||||
raise ValueError("For 'graph' mode, the outermost network does not support passing "
|
||||
"variable key-value pair parameters.")
|
||||
self.check_construct_args(*inputs, **kwargs)
|
||||
if self.enable_hook:
|
||||
raise ValueError("The graph mode does not support hook function.")
|
||||
out = self.compile_and_run(*inputs)
|
||||
|
|
|
@ -87,7 +87,7 @@ def test_l1_regularizer07():
|
|||
try:
|
||||
l1_regularizer = Net_l1_regularizer(scale)
|
||||
l1_regularizer()
|
||||
except TypeError:
|
||||
except ValueError:
|
||||
assert True
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue