support interface 'all' and 'any' of tensor

This commit is contained in:
buxue 2020-08-04 09:47:42 +08:00
parent b0b4fa08b3
commit 2c4cb49a11
10 changed files with 185 additions and 9 deletions

View File

@ -31,6 +31,41 @@ trans = P.Transpose()
shape_ = P.Shape()
dtype_ = P.DType()
def all_(x, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
x (Tensor): A Tensor to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
reduce_all = P.ReduceAll(keep_dims)
return reduce_all(x, axis)
def any_(x, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
x (Tensor): A Tensor to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
reduce_any = P.ReduceAny(keep_dims)
return reduce_any(x, axis)
def transpose(x):
"""Implementation of `transpose`."""
shape = F.shape(x)
@ -157,7 +192,6 @@ def check_is_const_int(x, op_name, arg_name):
return True
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
@ -316,4 +350,5 @@ def to_array(x):
"""Implementation of `to_array`."""
return x.__ms_to_array__()
tensor_operator_registry.register('__bool__', tensor_bool)

View File

@ -143,6 +143,8 @@ BuiltInTypeMap &GetMethodMap() {
}},
{kObjectTypeTensorType,
{
{"all", std::string("all_")}, // C.reduce_all
{"any", std::string("any_")}, // C.reduce_any
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul

View File

@ -35,9 +35,11 @@ class Registry(UserDict):
new_args = list(args)
new_args.append(obj_str)
return self["vm_compare"](*new_args)
obj = wrap
else:
obj = self[obj_str]
return obj
tensor_operator_registry = Registry()

View File

@ -27,7 +27,6 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
np.float32, np.float64, np.bool_)
class Tensor(Tensor_):
"""
Tensor for data storage.
@ -205,6 +204,34 @@ class Tensor(Tensor_):
return "Unknown Tensor type!"
return str(self.asnumpy())
def all(self, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
return tensor_operator_registry.get('all')(keep_dims)(self, axis)
def any(self, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
Tensor, has the same data type as x.
"""
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
@ -257,6 +284,7 @@ class IndexedSlices:
>>> values = Tensor([[1, 2]], dtype=ms.float32)
>>> Net((3, 2))(indices, values)
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
@ -297,5 +325,6 @@ class SparseTensor:
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> Net((3, 4))(indices, values)
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError

View File

@ -30,7 +30,6 @@ dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.add_prim_attr('const_value', True)
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
fill = P.Fill()
@ -67,6 +66,7 @@ assign_sub = P.AssignSub()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
@ -83,7 +83,6 @@ partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
list_getitem = Primitive('list_getitem')
@ -102,7 +101,6 @@ tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
make_ref = Primitive("make_ref")
scalar_add = Primitive('scalar_add')
scalar_mul = Primitive('scalar_mul')
scalar_sub = Primitive('scalar_sub')
@ -154,7 +152,6 @@ shape_mul = Primitive("shape_mul")
# a primitive to compare between tuple.
stop_gradient = Primitive("stop_gradient")
make_indexed_slices = Primitive('MakeIndexedSlices')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
@ -172,7 +169,9 @@ tensor_operator_registry.register('__truediv__', tensor_div)
tensor_operator_registry.register('__mod__', tensor_mod)
tensor_operator_registry.register('__pow__', tensor_pow)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
@ -181,6 +180,6 @@ tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
# support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)

View File

@ -1111,6 +1111,7 @@ class Mul(_MathBinaryOp):
>>> mul(input_x, input_y)
[4, 10, 18]
"""
def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()

View File

@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test interface 'all' and 'any' of tensor """
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
def test_all_and_any_of_tensor_in_graph():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
all_ = x.all()
any_ = x.any()
all_0 = x.all(0, True)
any_0 = x.any(0, True)
return all_, any_, all_0, any_0
net = Net()
x = Tensor(np.array([[True, False, False], [True, False, False]]))
context.set_context(mode=context.GRAPH_MODE)
net(x)
def test_all_and_any_of_tensor_in_pynative():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
all_ = x.all()
any_ = x.any()
all_0 = x.all(0, True)
any_0 = x.any(0, True)
return all_, any_, all_0, any_0
net = Net()
x = Tensor(np.array([[True, False, True], [True, False, False]]))
context.set_context(mode=context.PYNATIVE_MODE)
ret = net(x)
assert ret[0].asnumpy() == np.array(False)
assert ret[1].asnumpy() == np.array(True)
assert ret[2].asnumpy().shape == np.array([[True, False, False]]).shape
assert (ret[2].asnumpy() == np.array([[True, False, False]])).all()
assert ret[3].shape == Tensor(np.array([[True, False, True]])).shape
assert (ret[3] == Tensor(np.array([[True, False, True]]))).all()

View File

@ -194,7 +194,19 @@ def vm_impl_all(self):
def vm_impl(x, axis):
x = x.asnumpy()
out = vm.all(x, axis)
out = vm.all(x, axis, self.keep_dims)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
"""Generate vm_impl function for Any"""
def vm_impl(x, axis):
x = x.asnumpy()
out = vm.any(x, axis, self.keep_dims)
return Tensor(out)
return vm_impl

View File

@ -67,3 +67,5 @@ setattr(vm, "tanh", tanh)
setattr(vm, "sigmoid", sigmoid)
setattr(vm, 'maximum', maximum)
setattr(vm, 'minimum', minimum)
setattr(vm, 'all', all_)
setattr(vm, 'any', any_)

View File

@ -840,3 +840,35 @@ def minimum(x, y):
numpy.ndarray, has the same type as x.
"""
return np.minimum(x, y)
def all_(x, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
x (numpy.ndarray): An array to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
numpy.ndarray, has the same type as x.
"""
axis = None if axis == () else axis
return np.all(x, axis, keepdims=keep_dims)
def any_(x, axis=(), keep_dims=False):
"""
Check any array element along a given axis evaluate to True.
Args:
x (numpy.ndarray): An array to be reduced.
axis (Union[None, int, tuple(int)): Dimensions of reduction.
keep_dims (bool): Whether to keep the reduced dimensions.
Returns:
numpy.ndarray, has the same type as x.
"""
axis = None if axis == () else axis
return np.any(x, axis, keepdims=keep_dims)