ms_function supports optional inputs.

This commit is contained in:
huangbingjian 2022-08-15 14:19:20 +08:00
parent b3d74f7270
commit 6f7d87082a
3 changed files with 105 additions and 6 deletions

View File

@ -26,7 +26,6 @@ import inspect
import importlib
from collections import OrderedDict
from functools import wraps
import numpy as np
import mindspore as ms
from mindspore import context
from mindspore import log as logger
@ -40,8 +39,7 @@ from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTenso
PynativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from mindspore.parallel._tensor import _load_tensor_by_layout
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, \
_to_full_tensor, _get_parameter_broadcast, _get_pipeline_stages
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages
from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
@ -109,6 +107,37 @@ def _check_all_tensor(sequence):
return True
def _handle_func_args(func, *args, **kwargs):
"""Handle the *args and **kwargs inputs of the function."""
if kwargs:
bound_arguments = inspect.signature(func).bind(*args, **kwargs)
bound_arguments.apply_defaults()
args = bound_arguments.args
kwargs = bound_arguments.kwargs
# After apply_defaults, kwargs should be empty here.
if kwargs:
raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
f"or there is a key in kwargs that is not used as a function argument, "
f"args: {args}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
for value in inspect.signature(func).parameters.values():
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
return args
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
if value.default is inspect.Parameter.empty:
positional_args += 1
else:
default_args += 1
if len(args) < positional_args:
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
if len(args) > positional_args + default_args:
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
f"default argument, total {positional_args + default_args}, but got {len(args)}.")
return args
sys_path = list(sys.path)
# Get the entry script path.
if sys.argv and sys.argv[0] != '':
@ -538,7 +567,8 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
hash_obj = int(time.time() * 1e9)
@wraps(func)
def staging_specialize(*args):
def staging_specialize(*args, **kwargs):
args = _handle_func_args(func, *args, **kwargs)
process_obj = None
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]

View File

@ -419,8 +419,9 @@ class Cell(Cell_):
def _check_construct_args(self, *inputs, **kwargs):
"""Check the args needed by the function construct"""
if kwargs:
raise ValueError(f"For 'Cell', expect no kwargs here, "
"maybe you pass wrong arguments, args: {inputs}, kwargs: {kwargs}")
raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
f"or there is a key in kwargs that is not used as a function argument. "
f"args: {inputs}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
for value in inspect.signature(self.construct).parameters.values():

View File

@ -416,3 +416,71 @@ def test_pynative_ms_function_with_tuple_inputs():
net = Net()
out = net((x, y))
assert (out[0].asnumpy() == np.ones([2, 2]) + 1).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function_with_optional_inputs():
"""
Feature: PyNative ms_function.
Description: PyNative ms_function with optional inputs.
Expectation: The calculation result is correct.
"""
@ms_function
def foo(x, y=1):
return x + y
a = Tensor(3, dtype=ms.int32)
assert foo(a).asnumpy() == 4
assert foo(a, 2).asnumpy() == 5
assert foo(a, y=3).asnumpy() == 6
assert foo(x=a, y=4).asnumpy() == 7
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function_with_args_inputs():
"""
Feature: PyNative ms_function.
Description: PyNative ms_function with *args.
Expectation: The calculation result is correct.
"""
@ms_function
def foo(x, *args):
return x + args[0] + args[1]
x = Tensor(3, dtype=ms.int32)
assert foo(x, 1, 2).asnumpy() == 6
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function_with_kwargs_inputs():
"""
Feature: PyNative ms_function.
Description: PyNative ms_function with **kwargs.
Expectation: No exception.
"""
@ms_function
def foo(x, **kwargs):
return x + kwargs.get('y')
with pytest.raises(ValueError):
x = Tensor(3, dtype=ms.int32)
data = {"y": 1}
foo(x, **data)