[dfx]optimize exception info print

This commit is contained in:
huanghui 2021-05-31 10:08:05 +08:00
parent 209a6e1df6
commit 5acf3c8ade
2 changed files with 27 additions and 4 deletions

View File

@ -327,6 +327,31 @@ class Cell(Cell_):
_pynative_exec.leave_construct(self)
return output
def check_construct_args(self, *inputs, **kwargs):
"""Check the args needed by the function construct"""
if kwargs:
raise ValueError("For 'graph' mode, the outermost network does not support passing "
"variable key-value pair parameters.")
positional_args = 0
default_args = 0
for value in inspect.signature(self.construct).parameters.values():
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
return
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
if value.default is inspect.Parameter.empty:
positional_args += 1
else:
default_args += 1
if len(inputs) < positional_args:
raise ValueError(
f"The function construct need {positional_args} positional argument, but only provided {len(inputs)}.")
if len(inputs) > positional_args + default_args:
raise ValueError(
f"The function construct need {positional_args} positional argument and {default_args} default "
f"argument, but provided {len(inputs)}")
def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
@ -336,9 +361,7 @@ class Cell(Cell_):
inputs = bound_args.args
kwargs = bound_args.kwargs
if context.get_context("mode") == context.GRAPH_MODE:
if kwargs:
raise ValueError("For 'graph' mode, the outermost network does not support passing "
"variable key-value pair parameters.")
self.check_construct_args(*inputs, **kwargs)
if self.enable_hook:
raise ValueError("The graph mode does not support hook function.")
out = self.compile_and_run(*inputs)

View File

@ -87,7 +87,7 @@ def test_l1_regularizer07():
try:
l1_regularizer = Net_l1_regularizer(scale)
l1_regularizer()
except TypeError:
except ValueError:
assert True