support interface 'all' and 'any' of tensor
This commit is contained in:
parent
b0b4fa08b3
commit
2c4cb49a11
|
@ -31,6 +31,41 @@ trans = P.Transpose()
|
|||
shape_ = P.Shape()
|
||||
dtype_ = P.DType()
|
||||
|
||||
|
||||
def all_(x, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check all array elements along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
x (Tensor): A Tensor to be reduced.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
|
||||
reduce_all = P.ReduceAll(keep_dims)
|
||||
return reduce_all(x, axis)
|
||||
|
||||
|
||||
def any_(x, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check any array element along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
x (Tensor): A Tensor to be reduced.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
|
||||
reduce_any = P.ReduceAny(keep_dims)
|
||||
return reduce_any(x, axis)
|
||||
|
||||
|
||||
def transpose(x):
|
||||
"""Implementation of `transpose`."""
|
||||
shape = F.shape(x)
|
||||
|
@ -157,7 +192,6 @@ def check_is_const_int(x, op_name, arg_name):
|
|||
return True
|
||||
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tensor_bool_cond(shp):
|
||||
"""check if tensor is a bool condition"""
|
||||
|
@ -316,4 +350,5 @@ def to_array(x):
|
|||
"""Implementation of `to_array`."""
|
||||
return x.__ms_to_array__()
|
||||
|
||||
|
||||
tensor_operator_registry.register('__bool__', tensor_bool)
|
||||
|
|
|
@ -143,6 +143,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"all", std::string("all_")}, // C.reduce_all
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
|
|
|
@ -35,9 +35,11 @@ class Registry(UserDict):
|
|||
new_args = list(args)
|
||||
new_args.append(obj_str)
|
||||
return self["vm_compare"](*new_args)
|
||||
|
||||
obj = wrap
|
||||
else:
|
||||
obj = self[obj_str]
|
||||
return obj
|
||||
|
||||
|
||||
tensor_operator_registry = Registry()
|
||||
|
|
|
@ -27,7 +27,6 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
|
|||
np.float32, np.float64, np.bool_)
|
||||
|
||||
|
||||
|
||||
class Tensor(Tensor_):
|
||||
"""
|
||||
Tensor for data storage.
|
||||
|
@ -205,6 +204,34 @@ class Tensor(Tensor_):
|
|||
return "Unknown Tensor type!"
|
||||
return str(self.asnumpy())
|
||||
|
||||
def all(self, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check all array elements along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
|
||||
return tensor_operator_registry.get('all')(keep_dims)(self, axis)
|
||||
|
||||
def any(self, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check any array element along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
|
||||
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
|
||||
|
||||
@property
|
||||
def virtual_flag(self):
|
||||
"""Mark tensor is virtual."""
|
||||
|
@ -257,6 +284,7 @@ class IndexedSlices:
|
|||
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
>>> Net((3, 2))(indices, values)
|
||||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -297,5 +325,6 @@ class SparseTensor:
|
|||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> Net((3, 4))(indices, values)
|
||||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -30,7 +30,6 @@ dtype = P.DType()
|
|||
isconstant = Primitive('is_constant')
|
||||
isconstant.add_prim_attr('const_value', True)
|
||||
|
||||
|
||||
issubclass_ = P.IsSubClass()
|
||||
isinstance_ = P.IsInstance()
|
||||
fill = P.Fill()
|
||||
|
@ -67,6 +66,7 @@ assign_sub = P.AssignSub()
|
|||
assign = P.Assign()
|
||||
square = P.Square()
|
||||
sqrt = P.Sqrt()
|
||||
|
||||
scalar_to_array = P.ScalarToArray()
|
||||
scalar_to_tensor = P.ScalarToTensor()
|
||||
tuple_to_array = P.TupleToArray()
|
||||
|
@ -83,7 +83,6 @@ partial = P.Partial()
|
|||
# depend: mount a node to another node
|
||||
depend = P.Depend()
|
||||
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
list_getitem = Primitive('list_getitem')
|
||||
|
@ -102,7 +101,6 @@ tuple_equal = Primitive("tuple_equal")
|
|||
list_equal = Primitive("list_equal")
|
||||
make_ref = Primitive("make_ref")
|
||||
|
||||
|
||||
scalar_add = Primitive('scalar_add')
|
||||
scalar_mul = Primitive('scalar_mul')
|
||||
scalar_sub = Primitive('scalar_sub')
|
||||
|
@ -154,7 +152,6 @@ shape_mul = Primitive("shape_mul")
|
|||
# a primitive to compare between tuple.
|
||||
stop_gradient = Primitive("stop_gradient")
|
||||
|
||||
|
||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
|
@ -172,7 +169,9 @@ tensor_operator_registry.register('__truediv__', tensor_div)
|
|||
tensor_operator_registry.register('__mod__', tensor_mod)
|
||||
tensor_operator_registry.register('__pow__', tensor_pow)
|
||||
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
|
||||
#ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('all', P.ReduceAll)
|
||||
tensor_operator_registry.register('any', P.ReduceAny)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
tensor_operator_registry.register('__neg__', neg_tensor)
|
||||
|
@ -181,6 +180,6 @@ tensor_operator_registry.register('__le__', tensor_le)
|
|||
tensor_operator_registry.register('__gt__', tensor_gt)
|
||||
tensor_operator_registry.register('__ge__', tensor_ge)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
#support GE backend for no compare operators
|
||||
# support GE backend for no compare operators
|
||||
tensor_operator_registry.register('vm_compare', BP.vm_compare)
|
||||
tensor_operator_registry.register('cast', cast)
|
||||
|
|
|
@ -1111,6 +1111,7 @@ class Mul(_MathBinaryOp):
|
|||
>>> mul(input_x, input_y)
|
||||
[4, 10, 18]
|
||||
"""
|
||||
|
||||
def infer_value(self, x, y):
|
||||
if x is not None and y is not None:
|
||||
x = x.asnumpy()
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test interface 'all' and 'any' of tensor """
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
def test_all_and_any_of_tensor_in_graph():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
all_ = x.all()
|
||||
any_ = x.any()
|
||||
all_0 = x.all(0, True)
|
||||
any_0 = x.any(0, True)
|
||||
return all_, any_, all_0, any_0
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.array([[True, False, False], [True, False, False]]))
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net(x)
|
||||
|
||||
|
||||
def test_all_and_any_of_tensor_in_pynative():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
all_ = x.all()
|
||||
any_ = x.any()
|
||||
all_0 = x.all(0, True)
|
||||
any_0 = x.any(0, True)
|
||||
return all_, any_, all_0, any_0
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.array([[True, False, True], [True, False, False]]))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
ret = net(x)
|
||||
assert ret[0].asnumpy() == np.array(False)
|
||||
assert ret[1].asnumpy() == np.array(True)
|
||||
assert ret[2].asnumpy().shape == np.array([[True, False, False]]).shape
|
||||
assert (ret[2].asnumpy() == np.array([[True, False, False]])).all()
|
||||
assert ret[3].shape == Tensor(np.array([[True, False, True]])).shape
|
||||
assert (ret[3] == Tensor(np.array([[True, False, True]]))).all()
|
|
@ -194,7 +194,19 @@ def vm_impl_all(self):
|
|||
|
||||
def vm_impl(x, axis):
|
||||
x = x.asnumpy()
|
||||
out = vm.all(x, axis)
|
||||
out = vm.all(x, axis, self.keep_dims)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.ReduceAny)
|
||||
def vm_impl_any(self):
|
||||
"""Generate vm_impl function for Any"""
|
||||
|
||||
def vm_impl(x, axis):
|
||||
x = x.asnumpy()
|
||||
out = vm.any(x, axis, self.keep_dims)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
|
|
@ -67,3 +67,5 @@ setattr(vm, "tanh", tanh)
|
|||
setattr(vm, "sigmoid", sigmoid)
|
||||
setattr(vm, 'maximum', maximum)
|
||||
setattr(vm, 'minimum', minimum)
|
||||
setattr(vm, 'all', all_)
|
||||
setattr(vm, 'any', any_)
|
||||
|
|
|
@ -840,3 +840,35 @@ def minimum(x, y):
|
|||
numpy.ndarray, has the same type as x.
|
||||
"""
|
||||
return np.minimum(x, y)
|
||||
|
||||
|
||||
def all_(x, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check all array elements along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
x (numpy.ndarray): An array to be reduced.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, has the same type as x.
|
||||
"""
|
||||
axis = None if axis == () else axis
|
||||
return np.all(x, axis, keepdims=keep_dims)
|
||||
|
||||
|
||||
def any_(x, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check any array element along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
x (numpy.ndarray): An array to be reduced.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, has the same type as x.
|
||||
"""
|
||||
axis = None if axis == () else axis
|
||||
return np.any(x, axis, keepdims=keep_dims)
|
||||
|
|
Loading…
Reference in New Issue