support abs and mean of Tensor

This commit is contained in:
Bairong 2020-11-27 19:36:12 +08:00
parent 3874160faf
commit 623b2e3f99
6 changed files with 172 additions and 0 deletions

View File

@ -30,6 +30,27 @@ trans = P.Transpose()
shape_ = P.Shape()
reshape_ = P.Reshape()
dtype_ = P.DType()
abs_ = P.Abs()
def mean(x, axis=(), keep_dims=False):
"""
Reduce a dimension of a tensor by averaging all elements in the dimension.
Args:
axis (Union[None, int, tuple(int)]): Dimensions of reduction,
when axis is None or empty tuple, reduce all dimensions.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
if axis is None:
axis = ()
reduce_mean = P.ReduceMean(keep_dims)
return reduce_mean(x, axis)
def all_(x, axis=(), keep_dims=False):
"""

View File

@ -152,6 +152,8 @@ BuiltInTypeMap &GetMethodMap() {
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"abs", std::string("abs_")}, // C.abs_
{"mean", std::string("mean")}, // C.mean
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod

View File

@ -325,6 +325,35 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('broadcast_to')(x.shape)(self)
def abs(self):
"""
Return absolute value element-wisely.
Returns:
Tensor, has the same data type as x.
"""
return tensor_operator_registry.get('abs')()(self)
def mean(self, axis=(), keep_dims=False):
"""
Reduce a dimension of a tensor by averaging all elements in the dimension.
Args:
axis (Union[None, int, tuple(int)]): Dimensions of reduction,
when axis is None or empty tuple, reduce all dimensions.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
if axis is None:
axis = ()
return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
class RowTensor:
"""
A sparse representation of a set of tensor slices at given indices.

View File

@ -173,6 +173,8 @@ tensor_operator_registry.register('__pow__', tensor_pow)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
# ms cannot support Tensor(True) compare

View File

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_abs """
import mindspore as ms
from mindspore import nn
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_abs():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = ms.Tensor([1, -2, 3])
def construct(self):
return self.value.abs()
net = Net()
net()
def test_abs_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return x.abs()
net = Net()
x = ms.Tensor([1, -2, 3])
net(x)

View File

@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_mean """
import mindspore as ms
from mindspore import nn
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_mean():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.value = ms.Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
def construct(self):
return self.value.mean()
net = Net()
net()
def test_mean_axis():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.value = ms.Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
def construct(self):
return self.value.mean(axis=1)
net = Net()
net()
def test_mean_parameter():
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
return x.mean()
x = ms.Tensor([[1, 2, 3], [1, 2, 3]], dtype=ms.float32)
net = Net()
net(x)
def test_mean_parameter_axis():
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
return x.mean(axis=1)
x = ms.Tensor([[1, 2, 3], [1, 2, 3]], dtype=ms.float32)
net = Net()
net(x)