forked from mindspore-Ecosystem/mindspore
Support the 'all' and 'any' interfaces of Tensor
This commit is contained in:
parent
b0b4fa08b3
commit
2c4cb49a11
|
@ -31,6 +31,41 @@ trans = P.Transpose()
|
||||||
shape_ = P.Shape()
|
shape_ = P.Shape()
|
||||||
dtype_ = P.DType()
|
dtype_ = P.DType()
|
||||||
|
|
||||||
|
|
||||||
|
def all_(x, axis=(), keep_dims=False):
    """
    Check all array elements along a given axis evaluate to True.

    Args:
        x (Tensor): A Tensor to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        Tensor, has the same data type as x.
    """
    # ReduceAll is parameterized by keep_dims at construction time,
    # then applied to the (input, axis) pair.
    reduce_all = P.ReduceAll(keep_dims)
    return reduce_all(x, axis)
|
||||||
|
|
||||||
|
|
||||||
|
def any_(x, axis=(), keep_dims=False):
    """
    Check any array element along a given axis evaluate to True.

    Args:
        x (Tensor): A Tensor to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        Tensor, has the same data type as x.
    """
    # ReduceAny is parameterized by keep_dims at construction time,
    # then applied to the (input, axis) pair.
    reduce_any = P.ReduceAny(keep_dims)
    return reduce_any(x, axis)
|
||||||
|
|
||||||
|
|
||||||
def transpose(x):
|
def transpose(x):
|
||||||
"""Implementation of `transpose`."""
|
"""Implementation of `transpose`."""
|
||||||
shape = F.shape(x)
|
shape = F.shape(x)
|
||||||
|
@ -157,7 +192,6 @@ def check_is_const_int(x, op_name, arg_name):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_is_tensor_bool_cond(shp):
|
def check_is_tensor_bool_cond(shp):
|
||||||
"""check if tensor is a bool condition"""
|
"""check if tensor is a bool condition"""
|
||||||
|
@ -316,4 +350,5 @@ def to_array(x):
|
||||||
"""Implementation of `to_array`."""
|
"""Implementation of `to_array`."""
|
||||||
return x.__ms_to_array__()
|
return x.__ms_to_array__()
|
||||||
|
|
||||||
|
|
||||||
tensor_operator_registry.register('__bool__', tensor_bool)
|
tensor_operator_registry.register('__bool__', tensor_bool)
|
||||||
|
|
|
@ -143,6 +143,8 @@ BuiltInTypeMap &GetMethodMap() {
|
||||||
}},
|
}},
|
||||||
{kObjectTypeTensorType,
|
{kObjectTypeTensorType,
|
||||||
{
|
{
|
||||||
|
{"all", std::string("all_")}, // C.reduce_all
|
||||||
|
{"any", std::string("any_")}, // C.reduce_any
|
||||||
{"__add__", std::string("add")}, // C.add
|
{"__add__", std::string("add")}, // C.add
|
||||||
{"__sub__", std::string("sub")}, // C.sub
|
{"__sub__", std::string("sub")}, // C.sub
|
||||||
{"__mul__", std::string("mul")}, // C.mul
|
{"__mul__", std::string("mul")}, // C.mul
|
||||||
|
|
|
@ -35,9 +35,11 @@ class Registry(UserDict):
|
||||||
new_args = list(args)
|
new_args = list(args)
|
||||||
new_args.append(obj_str)
|
new_args.append(obj_str)
|
||||||
return self["vm_compare"](*new_args)
|
return self["vm_compare"](*new_args)
|
||||||
|
|
||||||
obj = wrap
|
obj = wrap
|
||||||
else:
|
else:
|
||||||
obj = self[obj_str]
|
obj = self[obj_str]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
tensor_operator_registry = Registry()
|
tensor_operator_registry = Registry()
|
||||||
|
|
|
@ -27,7 +27,6 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||||
np.float32, np.float64, np.bool_)
|
np.float32, np.float64, np.bool_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Tensor(Tensor_):
|
class Tensor(Tensor_):
|
||||||
"""
|
"""
|
||||||
Tensor for data storage.
|
Tensor for data storage.
|
||||||
|
@ -205,6 +204,34 @@ class Tensor(Tensor_):
|
||||||
return "Unknown Tensor type!"
|
return "Unknown Tensor type!"
|
||||||
return str(self.asnumpy())
|
return str(self.asnumpy())
|
||||||
|
|
||||||
|
def all(self, axis=(), keep_dims=False):
    """
    Check all array elements along a given axis evaluate to True.

    Args:
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        Tensor, has the same data type as self.
    """
    # The registry stores the P.ReduceAll class; instantiate it with
    # keep_dims, then reduce this tensor along `axis`.
    return tensor_operator_registry.get('all')(keep_dims)(self, axis)
|
||||||
|
|
||||||
|
def any(self, axis=(), keep_dims=False):
    """
    Check any array element along a given axis evaluate to True.

    Args:
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        Tensor, has the same data type as self.
    """
    # The registry stores the P.ReduceAny class; instantiate it with
    # keep_dims, then reduce this tensor along `axis`.
    return tensor_operator_registry.get('any')(keep_dims)(self, axis)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def virtual_flag(self):
|
def virtual_flag(self):
|
||||||
"""Mark tensor is virtual."""
|
"""Mark tensor is virtual."""
|
||||||
|
@ -257,6 +284,7 @@ class IndexedSlices:
|
||||||
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
>>> values = Tensor([[1, 2]], dtype=ms.float32)
|
||||||
>>> Net((3, 2))(indices, values)
|
>>> Net((3, 2))(indices, values)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -297,5 +325,6 @@ class SparseTensor:
|
||||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||||
>>> Net((3, 4))(indices, values)
|
>>> Net((3, 4))(indices, values)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -30,7 +30,6 @@ dtype = P.DType()
|
||||||
isconstant = Primitive('is_constant')
|
isconstant = Primitive('is_constant')
|
||||||
isconstant.add_prim_attr('const_value', True)
|
isconstant.add_prim_attr('const_value', True)
|
||||||
|
|
||||||
|
|
||||||
issubclass_ = P.IsSubClass()
|
issubclass_ = P.IsSubClass()
|
||||||
isinstance_ = P.IsInstance()
|
isinstance_ = P.IsInstance()
|
||||||
fill = P.Fill()
|
fill = P.Fill()
|
||||||
|
@ -67,6 +66,7 @@ assign_sub = P.AssignSub()
|
||||||
assign = P.Assign()
|
assign = P.Assign()
|
||||||
square = P.Square()
|
square = P.Square()
|
||||||
sqrt = P.Sqrt()
|
sqrt = P.Sqrt()
|
||||||
|
|
||||||
scalar_to_array = P.ScalarToArray()
|
scalar_to_array = P.ScalarToArray()
|
||||||
scalar_to_tensor = P.ScalarToTensor()
|
scalar_to_tensor = P.ScalarToTensor()
|
||||||
tuple_to_array = P.TupleToArray()
|
tuple_to_array = P.TupleToArray()
|
||||||
|
@ -83,7 +83,6 @@ partial = P.Partial()
|
||||||
# depend: mount a node to another node
|
# depend: mount a node to another node
|
||||||
depend = P.Depend()
|
depend = P.Depend()
|
||||||
|
|
||||||
|
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
list_getitem = Primitive('list_getitem')
|
list_getitem = Primitive('list_getitem')
|
||||||
|
@ -102,7 +101,6 @@ tuple_equal = Primitive("tuple_equal")
|
||||||
list_equal = Primitive("list_equal")
|
list_equal = Primitive("list_equal")
|
||||||
make_ref = Primitive("make_ref")
|
make_ref = Primitive("make_ref")
|
||||||
|
|
||||||
|
|
||||||
scalar_add = Primitive('scalar_add')
|
scalar_add = Primitive('scalar_add')
|
||||||
scalar_mul = Primitive('scalar_mul')
|
scalar_mul = Primitive('scalar_mul')
|
||||||
scalar_sub = Primitive('scalar_sub')
|
scalar_sub = Primitive('scalar_sub')
|
||||||
|
@ -154,7 +152,6 @@ shape_mul = Primitive("shape_mul")
|
||||||
# a primitive to compare between tuple.
|
# a primitive to compare between tuple.
|
||||||
stop_gradient = Primitive("stop_gradient")
|
stop_gradient = Primitive("stop_gradient")
|
||||||
|
|
||||||
|
|
||||||
make_indexed_slices = Primitive('MakeIndexedSlices')
|
make_indexed_slices = Primitive('MakeIndexedSlices')
|
||||||
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||||
|
@ -172,7 +169,9 @@ tensor_operator_registry.register('__truediv__', tensor_div)
|
||||||
tensor_operator_registry.register('__mod__', tensor_mod)
|
tensor_operator_registry.register('__mod__', tensor_mod)
|
||||||
tensor_operator_registry.register('__pow__', tensor_pow)
|
tensor_operator_registry.register('__pow__', tensor_pow)
|
||||||
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
|
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
|
||||||
#ms cannot support Tensor(True) compare
|
tensor_operator_registry.register('all', P.ReduceAll)
|
||||||
|
tensor_operator_registry.register('any', P.ReduceAny)
|
||||||
|
# ms cannot support Tensor(True) compare
|
||||||
tensor_operator_registry.register('__eq__', equal)
|
tensor_operator_registry.register('__eq__', equal)
|
||||||
tensor_operator_registry.register('__ne__', not_equal)
|
tensor_operator_registry.register('__ne__', not_equal)
|
||||||
tensor_operator_registry.register('__neg__', neg_tensor)
|
tensor_operator_registry.register('__neg__', neg_tensor)
|
||||||
|
@ -181,6 +180,6 @@ tensor_operator_registry.register('__le__', tensor_le)
|
||||||
tensor_operator_registry.register('__gt__', tensor_gt)
|
tensor_operator_registry.register('__gt__', tensor_gt)
|
||||||
tensor_operator_registry.register('__ge__', tensor_ge)
|
tensor_operator_registry.register('__ge__', tensor_ge)
|
||||||
tensor_operator_registry.register('shape', shape)
|
tensor_operator_registry.register('shape', shape)
|
||||||
#support GE backend for no compare operators
|
# support GE backend for no compare operators
|
||||||
tensor_operator_registry.register('vm_compare', BP.vm_compare)
|
tensor_operator_registry.register('vm_compare', BP.vm_compare)
|
||||||
tensor_operator_registry.register('cast', cast)
|
tensor_operator_registry.register('cast', cast)
|
||||||
|
|
|
@ -1111,6 +1111,7 @@ class Mul(_MathBinaryOp):
|
||||||
>>> mul(input_x, input_y)
|
>>> mul(input_x, input_y)
|
||||||
[4, 10, 18]
|
[4, 10, 18]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def infer_value(self, x, y):
|
def infer_value(self, x, y):
|
||||||
if x is not None and y is not None:
|
if x is not None and y is not None:
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test interface 'all' and 'any' of tensor """
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_and_any_of_tensor_in_graph():
    """Tensor.all/any should compile and return correct values in graph mode."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            all_ = x.all()
            any_ = x.any()
            all_0 = x.all(0, True)
            any_0 = x.any(0, True)
            return all_, any_, all_0, any_0

    net = Net()
    x = Tensor(np.array([[True, False, False], [True, False, False]]))
    context.set_context(mode=context.GRAPH_MODE)
    ret = net(x)
    # Previously this test only checked that graph compilation/execution
    # did not crash; also verify the numerical results.
    assert ret[0].asnumpy() == np.array(False)
    assert ret[1].asnumpy() == np.array(True)
    assert (ret[2].asnumpy() == np.array([[True, False, False]])).all()
    assert (ret[3].asnumpy() == np.array([[True, False, False]])).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_and_any_of_tensor_in_pynative():
    """Tensor.all/any should return correct values in PyNative mode."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            all_ = x.all()
            any_ = x.any()
            all_0 = x.all(0, True)
            any_0 = x.any(0, True)
            return all_, any_, all_0, any_0

    net = Net()
    x = Tensor(np.array([[True, False, True], [True, False, False]]))
    context.set_context(mode=context.PYNATIVE_MODE)
    ret = net(x)
    # Compare via numpy throughout: the original mixed asnumpy() checks
    # with Tensor-vs-Tensor comparison, which depends on the vm_compare
    # fallback and obscured what was actually being tested.
    assert ret[0].asnumpy() == np.array(False)
    assert ret[1].asnumpy() == np.array(True)
    assert ret[2].asnumpy().shape == (1, 3)
    assert (ret[2].asnumpy() == np.array([[True, False, False]])).all()
    assert ret[3].asnumpy().shape == (1, 3)
    assert (ret[3].asnumpy() == np.array([[True, False, True]])).all()
|
|
@ -194,7 +194,19 @@ def vm_impl_all(self):
|
||||||
|
|
||||||
def vm_impl(x, axis):
|
def vm_impl(x, axis):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
out = vm.all(x, axis)
|
out = vm.all(x, axis, self.keep_dims)
|
||||||
|
return Tensor(out)
|
||||||
|
|
||||||
|
return vm_impl
|
||||||
|
|
||||||
|
|
||||||
|
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
    """Generate vm_impl function for Any"""
    def vm_impl(x, axis):
        # Convert to numpy, reduce with the vm backend, wrap the
        # result back into a Tensor.
        data = x.asnumpy()
        reduced = vm.any(data, axis, self.keep_dims)
        return Tensor(reduced)
    return vm_impl
|
||||||
|
|
|
@ -67,3 +67,5 @@ setattr(vm, "tanh", tanh)
|
||||||
setattr(vm, "sigmoid", sigmoid)
|
setattr(vm, "sigmoid", sigmoid)
|
||||||
setattr(vm, 'maximum', maximum)
|
setattr(vm, 'maximum', maximum)
|
||||||
setattr(vm, 'minimum', minimum)
|
setattr(vm, 'minimum', minimum)
|
||||||
|
setattr(vm, 'all', all_)
|
||||||
|
setattr(vm, 'any', any_)
|
||||||
|
|
|
@ -840,3 +840,35 @@ def minimum(x, y):
|
||||||
numpy.ndarray, has the same type as x.
|
numpy.ndarray, has the same type as x.
|
||||||
"""
|
"""
|
||||||
return np.minimum(x, y)
|
return np.minimum(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def all_(x, axis=(), keep_dims=False):
    """
    Check all array elements along a given axis evaluate to True.

    Args:
        x (numpy.ndarray): An array to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        numpy.ndarray, has the same type as x.
    """
    # MindSpore uses an empty tuple to mean "reduce over all axes";
    # numpy expects None for that case.
    axis = None if axis == () else axis
    return np.all(x, axis, keepdims=keep_dims)
|
||||||
|
|
||||||
|
|
||||||
|
def any_(x, axis=(), keep_dims=False):
    """
    Check any array element along a given axis evaluate to True.

    Args:
        x (numpy.ndarray): An array to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        numpy.ndarray, has the same type as x.
    """
    # MindSpore uses an empty tuple to mean "reduce over all axes";
    # numpy expects None for that case.
    axis = None if axis == () else axis
    return np.any(x, axis, keepdims=keep_dims)
|
||||||
|
|
Loading…
Reference in New Issue