add sponge frontend

This commit is contained in:
huangmengxi 2021-08-04 10:23:23 +08:00
parent ec1cde222e
commit 5c9bc28798
5 changed files with 22 additions and 8 deletions

View File

@ -23,6 +23,7 @@ from itertools import repeat, zip_longest
from collections import deque
from collections.abc import Iterable
import numpy as np
from mindspore import context
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore._c_expression import Tensor as Tensor_
@ -846,6 +847,10 @@ class Validator:
"""Returns an empty Tensor."""
return Tensor_(dtype, shape)
@staticmethod
def check_type_support(dtype, device, supported_dtypes):
return dtype in supported_dtypes or not context.get_context('device_target') == device
def check_input_format(input_param):
"""Judge input format."""

View File

@ -1312,7 +1312,8 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
>>> print(input_x.sum(axis=1))
[10. 35.]
"""
dtype = x.dtype if dtype is None else dtype
input_x = x.astype(mstype.int32) if x.dtype == mstype.bool_ else x
dtype = input_x.dtype if dtype is None else dtype
if not isinstance(keepdims, int):
const_utils.raise_type_error("integer argument expected")
if initial is not None and not isinstance(initial, (int, float, bool)):
@ -1322,14 +1323,14 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
else:
axis = check_and_canonicalize_axes(axis, x.ndim)
if x.dtype == mstype.bool_:
x = x.astype("int32")
if not check_type_support(input_x.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)):
input_x = input_x.astype(mstype.float32)
if 0 in x.shape:
x = const_utils.make_tensor([0], x.dtype)
if keepdims:
res = _reduce_sum_keepdims(x, axis)
res = _reduce_sum_keepdims(input_x, axis)
else:
res = _reduce_sum_default(x, axis)
res = _reduce_sum_default(input_x, axis)
if initial is not None:
res += initial
return res.astype(dtype)
@ -1648,6 +1649,7 @@ get_log2_size = constexpr(validator.get_log2_size)
check_axis_type = constexpr(validator.check_axis_type)
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
empty_compile = constexpr(validator.empty_compile)
check_type_support = constexpr(validator.check_type_support)
def tensor_bool(x):

View File

@ -1798,7 +1798,8 @@ class Tensor(Tensor_):
>>> print(input_x.sum(axis=1))
[10. 35.]
"""
dtype = self.dtype if dtype is None else dtype
input_x = self.astype(mstype.int32) if self.dtype == mstype.bool_ else self
dtype = input_x.dtype if dtype is None else dtype
if not isinstance(keepdims, int):
raise TypeError(f"integer argument expected, but got {type(keepdims)}")
if initial is not None and not isinstance(initial, (int, float, bool)):
@ -1808,7 +1809,9 @@ class Tensor(Tensor_):
else:
axis = validator.check_and_canonicalize_axes(axis, self.ndim)
input_x = self.astype(mstype.int32) if self.dtype == mstype.bool_ else self
if not validator.check_type_support(input_x.dtype, 'GPU',
(mstype.float64, mstype.float32, mstype.float16)):
input_x = input_x.astype(mstype.float32)
if 0 in self.shape:
input_x = tensor_operator_registry.get('make_tensor')([0], self.dtype)
res = tensor_operator_registry.get('sum')(bool(keepdims))(input_x, axis)

View File

@ -89,7 +89,7 @@ def array(obj, dtype=None, copy=True, ndmin=0):
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
res = _expand(res, ndmin)
if copy:
if copy and isinstance(obj, Tensor):
res = copy_(res)
elif dtype is not None and dtype != res.dtype:
res = res.astype(dtype)

View File

@ -116,6 +116,10 @@ bitwise_and = P.BitwiseAnd()
bitwise_or = P.BitwiseOr()
bitwise_xor = P.BitwiseXor()
invert = P.Invert()
erf = P.Erf()
erfc = P.Erfc()
sort = P.Sort()
tensor_range = P.Range()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()