forked from mindspore-Ecosystem/mindspore
Add Ones and Zeros operators
This commit is contained in:
parent
c5b5a6719c
commit
886ef520d7
|
@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
|
|||
from .image_ops import (CropAndResize)
|
||||
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
|
||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||
|
|
|
@ -998,6 +998,93 @@ class Fill(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Ones(PrimitiveWithInfer):
|
||||
"""
|
||||
Creates a tensor filled with value ones.
|
||||
|
||||
Creates a tensor with shape described by the first argument and
|
||||
fills it with value ones in type of the second argument.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
|
||||
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type and shape as input value.
|
||||
|
||||
Examples:
|
||||
>>> ones = P.Ones()
|
||||
>>> Ones((2, 2), mindspore.float32)
|
||||
[[1.0, 1.0],
|
||||
[1.0, 1.0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Fill"""
|
||||
|
||||
def __infer__(self, dims, dtype):
|
||||
validator.check_value_type("shape", dims['value'], [tuple], self.name)
|
||||
for i, item in enumerate(dims['value']):
|
||||
validator.check_positive_int(item, f'dims[{i}]', self.name)
|
||||
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint32, mstype.uint64,
|
||||
mstype.float16, mstype.float32, mstype.float64]
|
||||
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
|
||||
x_nptype = mstype.dtype_to_nptype(dtype['value'])
|
||||
ret = np.ones(dims['value'], x_nptype)
|
||||
out = {
|
||||
'value': Tensor(ret),
|
||||
'shape': dims['value'],
|
||||
'dtype': x_nptype,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class Zeros(PrimitiveWithInfer):
|
||||
"""
|
||||
Creates a tensor filled with value zeros.
|
||||
|
||||
Creates a tensor with shape described by the first argument and
|
||||
fills it with value zeros in type of the second argument.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
|
||||
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type and shape as input value.
|
||||
|
||||
Examples:
|
||||
>>> zeros = P.Zeros()
|
||||
>>> Zeros((2, 2), mindspore.float32)
|
||||
[[0.0, 0.0],
|
||||
[0.0, 0.0]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Fill"""
|
||||
|
||||
def __infer__(self, dims, dtype):
|
||||
validator.check_value_type("shape", dims['value'], [tuple], self.name)
|
||||
for i, item in enumerate(dims['value']):
|
||||
validator.check_positive_int(item, f'dims[{i}]', self.name)
|
||||
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint32, mstype.uint64,
|
||||
mstype.float16, mstype.float32, mstype.float64]
|
||||
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
|
||||
x_nptype = mstype.dtype_to_nptype(dtype['value'])
|
||||
ret = np.zeros(dims['value'], x_nptype)
|
||||
out = {
|
||||
'value': Tensor(ret),
|
||||
'shape': dims['value'],
|
||||
'dtype': x_nptype,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class OnesLike(PrimitiveWithInfer):
|
||||
"""
|
||||
Creates a new tensor. The values of all elements are 1.
|
||||
|
|
|
@ -52,6 +52,20 @@ def test_cast():
|
|||
assert np.all(result.asnumpy() == expect)
|
||||
|
||||
|
||||
def test_ones():
|
||||
ones = P.Ones()
|
||||
output = ones((2, 3), mstype.int32)
|
||||
assert output.asnumpy().shape == (2, 3)
|
||||
assert np.sum(output.asnumpy()) == 6
|
||||
|
||||
|
||||
def test_zeros():
|
||||
zeros = P.Zeros()
|
||||
output = zeros((2, 3), mstype.int32)
|
||||
assert output.asnumpy().shape == (2, 3)
|
||||
assert np.sum(output.asnumpy()) == 0
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_reshape():
|
||||
input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))
|
||||
|
|
Loading…
Reference in New Issue