Add Ones and Zeros operators

This commit is contained in:
l00591931 2020-11-12 15:07:53 +08:00
parent c5b5a6719c
commit 886ef520d7
3 changed files with 102 additions and 1 deletions

View File

@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
from .image_ops import (CropAndResize)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,

View File

@ -998,6 +998,93 @@ class Fill(PrimitiveWithInfer):
return out
class Ones(PrimitiveWithInfer):
"""
Creates a tensor filled with value ones.
Creates a tensor with shape described by the first argument and
fills it with value ones in type of the second argument.
Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs:
Tensor, has the same type and shape as input value.
Examples:
>>> ones = P.Ones()
>>> Ones((2, 2), mindspore.float32)
[[1.0, 1.0],
[1.0, 1.0]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Fill"""
def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.ones(dims['value'], x_nptype)
out = {
'value': Tensor(ret),
'shape': dims['value'],
'dtype': x_nptype,
}
return out
class Zeros(PrimitiveWithInfer):
"""
Creates a tensor filled with value zeros.
Creates a tensor with shape described by the first argument and
fills it with value zeros in type of the second argument.
Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs:
Tensor, has the same type and shape as input value.
Examples:
>>> zeros = P.Zeros()
>>> Zeros((2, 2), mindspore.float32)
[[0.0, 0.0],
[0.0, 0.0]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Fill"""
def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.zeros(dims['value'], x_nptype)
out = {
'value': Tensor(ret),
'shape': dims['value'],
'dtype': x_nptype,
}
return out
class OnesLike(PrimitiveWithInfer):
"""
Creates a new tensor. The values of all elements are 1.

View File

@ -52,6 +52,20 @@ def test_cast():
assert np.all(result.asnumpy() == expect)
def test_ones():
ones = P.Ones()
output = ones((2, 3), mstype.int32)
assert output.asnumpy().shape == (2, 3)
assert np.sum(output.asnumpy()) == 6
def test_zeros():
zeros = P.Zeros()
output = zeros((2, 3), mstype.int32)
assert output.asnumpy().shape == (2, 3)
assert np.sum(output.asnumpy()) == 0
@non_graph_engine
def test_reshape():
input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))