forked from OSSInnovation/mindspore
!481 Move args_type_check function to _checkparam.py
Merge pull request !481 from leonwanghui/rm-code
This commit is contained in:
commit
862d23fe90
|
@ -14,8 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Check parameters."""
|
"""Check parameters."""
|
||||||
import re
|
import re
|
||||||
|
import inspect
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce, wraps
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
@ -181,7 +182,7 @@ class Validator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_subclass(arg_name, type_, template_type, prim_name):
|
def check_subclass(arg_name, type_, template_type, prim_name):
|
||||||
"""Checks whether some type is sublcass of another type"""
|
"""Checks whether some type is subclass of another type"""
|
||||||
if not isinstance(template_type, Iterable):
|
if not isinstance(template_type, Iterable):
|
||||||
template_type = (template_type,)
|
template_type = (template_type,)
|
||||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
||||||
|
@ -240,7 +241,6 @@ class Validator:
|
||||||
elem_types = map(_check_tensor_type, args.items())
|
elem_types = map(_check_tensor_type, args.items())
|
||||||
reduce(_check_types_same, elem_types)
|
reduce(_check_types_same, elem_types)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
||||||
"""
|
"""
|
||||||
|
@ -261,7 +261,7 @@ class Validator:
|
||||||
def _check_types_same(arg1, arg2):
|
def _check_types_same(arg1, arg2):
|
||||||
arg1_name, arg1_type = arg1
|
arg1_name, arg1_type = arg1
|
||||||
arg2_name, arg2_type = arg2
|
arg2_name, arg2_type = arg2
|
||||||
excp_flag = False
|
except_flag = False
|
||||||
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
||||||
arg1_type = arg1_type.element_type()
|
arg1_type = arg1_type.element_type()
|
||||||
arg2_type = arg2_type.element_type()
|
arg2_type = arg2_type.element_type()
|
||||||
|
@ -271,9 +271,9 @@ class Validator:
|
||||||
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
||||||
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
||||||
else:
|
else:
|
||||||
excp_flag = True
|
except_flag = True
|
||||||
|
|
||||||
if excp_flag or arg1_type != arg2_type:
|
if except_flag or arg1_type != arg2_type:
|
||||||
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
||||||
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
||||||
return arg1
|
return arg1
|
||||||
|
@ -283,11 +283,12 @@ class Validator:
|
||||||
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
||||||
"""Checks whether a value is instance of some types."""
|
"""Checks whether a value is instance of some types."""
|
||||||
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
||||||
|
|
||||||
def raise_error_msg():
|
def raise_error_msg():
|
||||||
"""func for raising error message when check failed"""
|
"""func for raising error message when check failed"""
|
||||||
type_names = [t.__name__ for t in valid_types]
|
type_names = [t.__name__ for t in valid_types]
|
||||||
num_types = len(valid_types)
|
num_types = len(valid_types)
|
||||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
||||||
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||||
|
|
||||||
|
@ -303,6 +304,7 @@ class Validator:
|
||||||
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
||||||
"""Checks whether a type in some specified types"""
|
"""Checks whether a type in some specified types"""
|
||||||
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
||||||
|
|
||||||
def get_typename(t):
|
def get_typename(t):
|
||||||
return t.__name__ if hasattr(t, '__name__') else str(t)
|
return t.__name__ if hasattr(t, '__name__') else str(t)
|
||||||
|
|
||||||
|
@ -368,9 +370,9 @@ class ParamValidator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_isinstance(arg_name, arg_value, classes):
|
def check_isinstance(arg_name, arg_value, classes):
|
||||||
"""Check arg isintance of classes"""
|
"""Check arg isinstance of classes"""
|
||||||
if not isinstance(arg_value, classes):
|
if not isinstance(arg_value, classes):
|
||||||
raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -384,7 +386,7 @@ class ParamValidator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_subclass(arg_name, type_, template_type, with_type_of=True):
|
def check_subclass(arg_name, type_, template_type, with_type_of=True):
|
||||||
"""Check whether some type is sublcass of another type"""
|
"""Check whether some type is subclass of another type"""
|
||||||
if not isinstance(template_type, Iterable):
|
if not isinstance(template_type, Iterable):
|
||||||
template_type = (template_type,)
|
template_type = (template_type,)
|
||||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
||||||
|
@ -402,9 +404,9 @@ class ParamValidator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_bool(arg_name, arg_value):
|
def check_bool(arg_name, arg_value):
|
||||||
"""Check arg isintance of bool"""
|
"""Check arg isinstance of bool"""
|
||||||
if not isinstance(arg_value, bool):
|
if not isinstance(arg_value, bool):
|
||||||
raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
|
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
|
||||||
if re.match(reg, target, flag) is None:
|
if re.match(reg, target, flag) is None:
|
||||||
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
|
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def args_type_check(*type_args, **type_kwargs):
|
||||||
|
"""Check whether input data type is correct."""
|
||||||
|
|
||||||
|
def type_check(func):
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
nonlocal bound_types
|
||||||
|
bound_values = sig.bind(*args, **kwargs)
|
||||||
|
argument_dict = bound_values.arguments
|
||||||
|
if "kwargs" in bound_types:
|
||||||
|
bound_types = bound_types["kwargs"]
|
||||||
|
if "kwargs" in argument_dict:
|
||||||
|
argument_dict = argument_dict["kwargs"]
|
||||||
|
for name, value in argument_dict.items():
|
||||||
|
if name in bound_types:
|
||||||
|
if value is not None and not isinstance(value, bound_types[name]):
|
||||||
|
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return type_check
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""
|
"""
|
||||||
Extension functions.
|
Extension functions.
|
||||||
|
|
||||||
Python functions that will be called in the c++ parts of MindSpore.
|
Python functions that will be called in the c++ parts of MindSpore.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,44 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""Pynative mode help module."""
|
|
||||||
from inspect import signature
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
|
|
||||||
def args_type_check(*type_args, **type_kwargs):
|
|
||||||
"""Check whether input data type is correct."""
|
|
||||||
|
|
||||||
def type_check(func):
|
|
||||||
sig = signature(func)
|
|
||||||
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
nonlocal bound_types
|
|
||||||
bound_values = sig.bind(*args, **kwargs)
|
|
||||||
argument_dict = bound_values.arguments
|
|
||||||
if "kwargs" in bound_types:
|
|
||||||
bound_types = bound_types["kwargs"]
|
|
||||||
if "kwargs" in argument_dict:
|
|
||||||
argument_dict = argument_dict["kwargs"]
|
|
||||||
for name, value in argument_dict.items():
|
|
||||||
if name in bound_types:
|
|
||||||
if value is not None and not isinstance(value, bound_types[name]):
|
|
||||||
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return type_check
|
|
|
@ -14,7 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""
|
"""
|
||||||
The context of mindspore, used to configure the current execution environment,
|
The context of mindspore, used to configure the current execution environment,
|
||||||
including execution mode, execution backend and other feature switchs.
|
including execution mode, execution backend and other feature switches.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
@ -22,7 +22,7 @@ from collections import namedtuple
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore._c_expression import MSContext
|
from mindspore._c_expression import MSContext
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
||||||
_reset_auto_parallel_context
|
_reset_auto_parallel_context
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ def _make_directory(path: str):
|
||||||
"""Make directory."""
|
"""Make directory."""
|
||||||
real_path = None
|
real_path = None
|
||||||
if path is None or not isinstance(path, str) or path.strip() == "":
|
if path is None or not isinstance(path, str) or path.strip() == "":
|
||||||
raise ValueError(f"Input path `{path}` is invaild type")
|
raise ValueError(f"Input path `{path}` is invalid type")
|
||||||
|
|
||||||
# convert the relative paths
|
# convert the relative paths
|
||||||
path = os.path.realpath(path)
|
path = os.path.realpath(path)
|
||||||
|
@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local):
|
||||||
"""
|
"""
|
||||||
Thread local Info used for store thread local attributes.
|
Thread local Info used for store thread local attributes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(_ThreadLocalInfo, self).__init__()
|
super(_ThreadLocalInfo, self).__init__()
|
||||||
self._reserve_class_name_in_scope = True
|
self._reserve_class_name_in_scope = True
|
||||||
|
@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local):
|
||||||
Args:
|
Args:
|
||||||
is_pynative (bool): Whether to adopt the PyNative mode.
|
is_pynative (bool): Whether to adopt the PyNative mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, is_pynative):
|
def __init__(self, is_pynative):
|
||||||
super(_ContextSwitchInfo, self).__init__()
|
super(_ContextSwitchInfo, self).__init__()
|
||||||
self.context_stack = []
|
self.context_stack = []
|
||||||
|
@ -209,7 +211,7 @@ class _Context:
|
||||||
def device_target(self, target):
|
def device_target(self, target):
|
||||||
success = self._context_handle.set_device_target(target)
|
success = self._context_handle.set_device_target(target)
|
||||||
if not success:
|
if not success:
|
||||||
raise ValueError("target device name is invalid!!!")
|
raise ValueError("Target device name is invalid!!!")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_id(self):
|
def device_id(self):
|
||||||
|
@ -335,7 +337,7 @@ class _Context:
|
||||||
|
|
||||||
@graph_memory_max_size.setter
|
@graph_memory_max_size.setter
|
||||||
def graph_memory_max_size(self, graph_memory_max_size):
|
def graph_memory_max_size(self, graph_memory_max_size):
|
||||||
if check_input_fotmat(graph_memory_max_size):
|
if check_input_format(graph_memory_max_size):
|
||||||
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
|
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
|
||||||
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
|
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
|
||||||
else:
|
else:
|
||||||
|
@ -347,7 +349,7 @@ class _Context:
|
||||||
|
|
||||||
@variable_memory_max_size.setter
|
@variable_memory_max_size.setter
|
||||||
def variable_memory_max_size(self, variable_memory_max_size):
|
def variable_memory_max_size(self, variable_memory_max_size):
|
||||||
if check_input_fotmat(variable_memory_max_size):
|
if check_input_format(variable_memory_max_size):
|
||||||
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
|
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
|
||||||
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
|
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
|
||||||
else:
|
else:
|
||||||
|
@ -367,12 +369,13 @@ class _Context:
|
||||||
thread_info.debug_runtime = enable
|
thread_info.debug_runtime = enable
|
||||||
|
|
||||||
|
|
||||||
def check_input_fotmat(x):
|
def check_input_format(x):
|
||||||
import re
|
import re
|
||||||
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
||||||
result = re.match(pattern, x)
|
result = re.match(pattern, x)
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
_k_context = None
|
_k_context = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ import threading
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
||||||
from mindspore._c_expression import AutoParallelContext
|
from mindspore._c_expression import AutoParallelContext
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
|
|
||||||
|
|
||||||
class _AutoParallelContext:
|
class _AutoParallelContext:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Context of cost_model in auto_parallel"""
|
"""Context of cost_model in auto_parallel"""
|
||||||
import threading
|
import threading
|
||||||
from mindspore._c_expression import CostModelContext
|
from mindspore._c_expression import CostModelContext
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
|
|
||||||
|
|
||||||
class _CostModelContext:
|
class _CostModelContext:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from mindspore._c_expression import CostModelContext
|
from mindspore._c_expression import CostModelContext
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
|
|
||||||
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||||
|
|
||||||
|
|
|
@ -14,16 +14,13 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" test_backend """
|
""" test_backend """
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context, ms_function
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
from mindspore.common.tensor import Tensor
|
|
||||||
from mindspore.common.api import ms_function
|
|
||||||
|
|
||||||
|
|
||||||
def setup_module(module):
|
def setup_module(module):
|
||||||
|
@ -32,6 +29,7 @@ def setup_module(module):
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
""" Net definition """
|
""" Net definition """
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.add = P.TensorAdd()
|
self.add = P.TensorAdd()
|
||||||
|
@ -50,6 +48,7 @@ def test_vm_backend():
|
||||||
output = add()
|
output = add()
|
||||||
assert output.asnumpy().shape == (1, 3, 3, 4)
|
assert output.asnumpy().shape == (1, 3, 3, 4)
|
||||||
|
|
||||||
|
|
||||||
def test_vm_set_context():
|
def test_vm_set_context():
|
||||||
""" test_vm_set_context """
|
""" test_vm_set_context """
|
||||||
context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE)
|
context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE)
|
||||||
|
@ -59,6 +58,7 @@ def test_vm_set_context():
|
||||||
assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0
|
assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
@args_type_check(v_str=str, v_int=int, v_tuple=tuple)
|
@args_type_check(v_str=str, v_int=int, v_tuple=tuple)
|
||||||
def check_input(v_str, v_int, v_tuple):
|
def check_input(v_str, v_int, v_tuple):
|
||||||
""" check_input """
|
""" check_input """
|
||||||
|
|
Loading…
Reference in New Issue