forked from mindspore-Ecosystem/mindspore
add sponge frontend
This commit is contained in:
parent
ec1cde222e
commit
5c9bc28798
|
@ -23,6 +23,7 @@ from itertools import repeat, zip_longest
|
|||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
@ -846,6 +847,10 @@ class Validator:
|
|||
"""Returns an empty Tensor."""
|
||||
return Tensor_(dtype, shape)
|
||||
|
||||
@staticmethod
|
||||
def check_type_support(dtype, device, supported_dtypes):
|
||||
return dtype in supported_dtypes or not context.get_context('device_target') == device
|
||||
|
||||
|
||||
def check_input_format(input_param):
|
||||
"""Judge input format."""
|
||||
|
|
|
@ -1312,7 +1312,8 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
|
|||
>>> print(input_x.sum(axis=1))
|
||||
[10. 35.]
|
||||
"""
|
||||
dtype = x.dtype if dtype is None else dtype
|
||||
input_x = x.astype(mstype.int32) if x.dtype == mstype.bool_ else x
|
||||
dtype = input_x.dtype if dtype is None else dtype
|
||||
if not isinstance(keepdims, int):
|
||||
const_utils.raise_type_error("integer argument expected")
|
||||
if initial is not None and not isinstance(initial, (int, float, bool)):
|
||||
|
@ -1322,14 +1323,14 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
|
|||
else:
|
||||
axis = check_and_canonicalize_axes(axis, x.ndim)
|
||||
|
||||
if x.dtype == mstype.bool_:
|
||||
x = x.astype("int32")
|
||||
if not check_type_support(input_x.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)):
|
||||
input_x = input_x.astype(mstype.float32)
|
||||
if 0 in x.shape:
|
||||
x = const_utils.make_tensor([0], x.dtype)
|
||||
if keepdims:
|
||||
res = _reduce_sum_keepdims(x, axis)
|
||||
res = _reduce_sum_keepdims(input_x, axis)
|
||||
else:
|
||||
res = _reduce_sum_default(x, axis)
|
||||
res = _reduce_sum_default(input_x, axis)
|
||||
if initial is not None:
|
||||
res += initial
|
||||
return res.astype(dtype)
|
||||
|
@ -1648,6 +1649,7 @@ get_log2_size = constexpr(validator.get_log2_size)
|
|||
check_axis_type = constexpr(validator.check_axis_type)
|
||||
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
||||
empty_compile = constexpr(validator.empty_compile)
|
||||
check_type_support = constexpr(validator.check_type_support)
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
|
|
|
@ -1798,7 +1798,8 @@ class Tensor(Tensor_):
|
|||
>>> print(input_x.sum(axis=1))
|
||||
[10. 35.]
|
||||
"""
|
||||
dtype = self.dtype if dtype is None else dtype
|
||||
input_x = self.astype(mstype.int32) if self.dtype == mstype.bool_ else self
|
||||
dtype = input_x.dtype if dtype is None else dtype
|
||||
if not isinstance(keepdims, int):
|
||||
raise TypeError(f"integer argument expected, but got {type(keepdims)}")
|
||||
if initial is not None and not isinstance(initial, (int, float, bool)):
|
||||
|
@ -1808,7 +1809,9 @@ class Tensor(Tensor_):
|
|||
else:
|
||||
axis = validator.check_and_canonicalize_axes(axis, self.ndim)
|
||||
|
||||
input_x = self.astype(mstype.int32) if self.dtype == mstype.bool_ else self
|
||||
if not validator.check_type_support(input_x.dtype, 'GPU',
|
||||
(mstype.float64, mstype.float32, mstype.float16)):
|
||||
input_x = input_x.astype(mstype.float32)
|
||||
if 0 in self.shape:
|
||||
input_x = tensor_operator_registry.get('make_tensor')([0], self.dtype)
|
||||
res = tensor_operator_registry.get('sum')(bool(keepdims))(input_x, axis)
|
||||
|
|
|
@ -89,7 +89,7 @@ def array(obj, dtype=None, copy=True, ndmin=0):
|
|||
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
|
||||
res = _expand(res, ndmin)
|
||||
|
||||
if copy:
|
||||
if copy and isinstance(obj, Tensor):
|
||||
res = copy_(res)
|
||||
elif dtype is not None and dtype != res.dtype:
|
||||
res = res.astype(dtype)
|
||||
|
|
|
@ -116,6 +116,10 @@ bitwise_and = P.BitwiseAnd()
|
|||
bitwise_or = P.BitwiseOr()
|
||||
bitwise_xor = P.BitwiseXor()
|
||||
invert = P.Invert()
|
||||
erf = P.Erf()
|
||||
erfc = P.Erfc()
|
||||
sort = P.Sort()
|
||||
tensor_range = P.Range()
|
||||
|
||||
scalar_to_array = P.ScalarToArray()
|
||||
scalar_to_tensor = P.ScalarToTensor()
|
||||
|
|
Loading…
Reference in New Issue