forked from mindspore-Ecosystem/mindspore
ms_function supports optional inputs.
This commit is contained in:
parent
b3d74f7270
commit
6f7d87082a
|
@ -26,7 +26,6 @@ import inspect
|
|||
import importlib
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
|
@ -40,8 +39,7 @@ from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTenso
|
|||
PynativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
|
||||
from mindspore.parallel._tensor import _load_tensor_by_layout
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
|
||||
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, \
|
||||
_to_full_tensor, _get_parameter_broadcast, _get_pipeline_stages
|
||||
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common._utils import is_shape_unknown
|
||||
from mindspore.common.mutable import mutable
|
||||
|
@ -109,6 +107,37 @@ def _check_all_tensor(sequence):
|
|||
return True
|
||||
|
||||
|
||||
def _handle_func_args(func, *args, **kwargs):
|
||||
"""Handle the *args and **kwargs inputs of the function."""
|
||||
if kwargs:
|
||||
bound_arguments = inspect.signature(func).bind(*args, **kwargs)
|
||||
bound_arguments.apply_defaults()
|
||||
args = bound_arguments.args
|
||||
kwargs = bound_arguments.kwargs
|
||||
# After apply_defaults, kwargs should be empty here.
|
||||
if kwargs:
|
||||
raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
|
||||
f"or there is a key in kwargs that is not used as a function argument, "
|
||||
f"args: {args}, kwargs: {kwargs}")
|
||||
|
||||
positional_args = 0
|
||||
default_args = 0
|
||||
for value in inspect.signature(func).parameters.values():
|
||||
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
return args
|
||||
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
if value.default is inspect.Parameter.empty:
|
||||
positional_args += 1
|
||||
else:
|
||||
default_args += 1
|
||||
if len(args) < positional_args:
|
||||
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
|
||||
if len(args) > positional_args + default_args:
|
||||
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
|
||||
f"default argument, total {positional_args + default_args}, but got {len(args)}.")
|
||||
return args
|
||||
|
||||
|
||||
sys_path = list(sys.path)
|
||||
# Get the entry script path.
|
||||
if sys.argv and sys.argv[0] != '':
|
||||
|
@ -538,7 +567,8 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|||
hash_obj = int(time.time() * 1e9)
|
||||
|
||||
@wraps(func)
|
||||
def staging_specialize(*args):
|
||||
def staging_specialize(*args, **kwargs):
|
||||
args = _handle_func_args(func, *args, **kwargs)
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
||||
process_obj = args[0]
|
||||
|
|
|
@ -419,8 +419,9 @@ class Cell(Cell_):
|
|||
def _check_construct_args(self, *inputs, **kwargs):
|
||||
"""Check the args needed by the function construct"""
|
||||
if kwargs:
|
||||
raise ValueError(f"For 'Cell', expect no kwargs here, "
|
||||
"maybe you pass wrong arguments, args: {inputs}, kwargs: {kwargs}")
|
||||
raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
|
||||
f"or there is a key in kwargs that is not used as a function argument. "
|
||||
f"args: {inputs}, kwargs: {kwargs}")
|
||||
positional_args = 0
|
||||
default_args = 0
|
||||
for value in inspect.signature(self.construct).parameters.values():
|
||||
|
|
|
@ -416,3 +416,71 @@ def test_pynative_ms_function_with_tuple_inputs():
|
|||
net = Net()
|
||||
out = net((x, y))
|
||||
assert (out[0].asnumpy() == np.ones([2, 2]) + 1).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_ms_function_with_optional_inputs():
|
||||
"""
|
||||
Feature: PyNative ms_function.
|
||||
Description: PyNative ms_function with optional inputs.
|
||||
Expectation: The calculation result is correct.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x, y=1):
|
||||
return x + y
|
||||
|
||||
a = Tensor(3, dtype=ms.int32)
|
||||
assert foo(a).asnumpy() == 4
|
||||
assert foo(a, 2).asnumpy() == 5
|
||||
assert foo(a, y=3).asnumpy() == 6
|
||||
assert foo(x=a, y=4).asnumpy() == 7
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_ms_function_with_args_inputs():
|
||||
"""
|
||||
Feature: PyNative ms_function.
|
||||
Description: PyNative ms_function with *args.
|
||||
Expectation: The calculation result is correct.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x, *args):
|
||||
return x + args[0] + args[1]
|
||||
|
||||
x = Tensor(3, dtype=ms.int32)
|
||||
assert foo(x, 1, 2).asnumpy() == 6
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_ms_function_with_kwargs_inputs():
|
||||
"""
|
||||
Feature: PyNative ms_function.
|
||||
Description: PyNative ms_function with **kwargs.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x, **kwargs):
|
||||
return x + kwargs.get('y')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
x = Tensor(3, dtype=ms.int32)
|
||||
data = {"y": 1}
|
||||
foo(x, **data)
|
||||
|
|
Loading…
Reference in New Issue