forked from mindspore-Ecosystem/mindspore
!47698 Add Ones/Zeros operators backend kernel
Merge pull request !47698 from yangshuo/br_01
This commit is contained in:
commit
fbfa41d834
|
@ -81,7 +81,6 @@ std::vector<std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc>> FillV2C
|
|||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int32_t)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
|
||||
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
|
||||
|
|
|
@ -104,7 +104,6 @@ std::vector<std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc>> FillV2G
|
|||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int32_t)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
|
||||
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
|
||||
|
|
|
@ -848,8 +848,18 @@ def ones(shape, dtype=None): # pylint: disable=redefined-outer-name
|
|||
[1. 1.]]
|
||||
"""
|
||||
_dtype = mstype.float32 if dtype is None else dtype
|
||||
ones_op = P.Ones()
|
||||
output = ones_op(shape, _dtype)
|
||||
ones_op = P.FillV2()
|
||||
value = Tensor(1, _dtype)
|
||||
if isinstance(shape, int):
|
||||
shape = tuple([shape])
|
||||
shape_tensor = shape
|
||||
if isinstance(shape, (list, tuple)) and not shape:
|
||||
shape_tensor = Tensor(shape, dtype=mstype.int64)
|
||||
elif not isinstance(shape, Tensor):
|
||||
shape_tensor = Tensor(shape)
|
||||
if shape_tensor.ndim == 0 and shape_tensor.size == 1:
|
||||
shape_tensor = shape_tensor.reshape(1)
|
||||
output = ones_op(shape_tensor, value)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -905,9 +915,19 @@ def zeros(shape, dtype=None): # pylint: disable=redefined-outer-name
|
|||
[[0. 0.]
|
||||
[0. 0.]]
|
||||
"""
|
||||
zero_op = P.Zeros()
|
||||
zero_op = P.FillV2()
|
||||
_dtype = mstype.float32 if dtype is None else dtype
|
||||
output = zero_op(shape, _dtype)
|
||||
value = Tensor(0, _dtype)
|
||||
if isinstance(shape, int):
|
||||
shape = tuple([shape])
|
||||
shape_tensor = shape
|
||||
if isinstance(shape, (list, tuple)) and not shape:
|
||||
shape_tensor = Tensor(shape, dtype=mstype.int64)
|
||||
elif not isinstance(shape, Tensor):
|
||||
shape_tensor = Tensor(shape)
|
||||
if shape_tensor.ndim == 0 and shape_tensor.size == 1:
|
||||
shape_tensor = shape_tensor.reshape(1)
|
||||
output = zero_op(shape_tensor, value)
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -1492,7 +1492,7 @@ class Fills(Primitive):
|
|||
self.init_prim_io_names(inputs=['x', 'value'], outputs=['y'])
|
||||
|
||||
|
||||
class FillV2(Primitive):
|
||||
class FillV2(PrimitiveWithCheck):
|
||||
"""
|
||||
Creates a tensor with shape described by `shape` and fills it with values in `value` .
|
||||
|
||||
|
@ -1534,11 +1534,11 @@ class FillV2(Primitive):
|
|||
self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])
|
||||
|
||||
def infer_value(self, dims, x):
|
||||
if isinstance(dims, Tensor_):
|
||||
dims = dims.asnumpy()
|
||||
if isinstance(x, Tensor_):
|
||||
x = x.asnumpy()
|
||||
if dims is not None and None not in dims and x is not None:
|
||||
if isinstance(dims, Tensor):
|
||||
dims = dims.asnumpy()
|
||||
if isinstance(x, Tensor):
|
||||
x = x.asnumpy()
|
||||
ret = np.full(dims, x)
|
||||
return Tensor(ret)
|
||||
return None
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import function as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import dtype_to_nptype
|
||||
|
||||
|
||||
class OnesNetDynTensor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(OnesNetDynTensor, self).__init__()
|
||||
self.unique = P.Unique()
|
||||
self.gather = P.Gather()
|
||||
self.x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int32))
|
||||
self.indices = Tensor(np.array([0, 1, 2, 6, 2, 1], dtype=np.int32))
|
||||
self.axis = 0
|
||||
|
||||
def construct(self, dtype):
|
||||
unique_indices, _ = self.unique(self.indices)
|
||||
input_x = self.gather(self.x, unique_indices, self.axis)
|
||||
return F.ones(input_x, dtype)
|
||||
|
||||
|
||||
def dyn_shape_tensor_run():
|
||||
net = OnesNetDynTensor()
|
||||
out = net(mstype.float32)
|
||||
expect = np.ones((1, 2, 3, 7), dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect)
|
||||
|
||||
|
||||
def ones_func_run(shape, dtype):
|
||||
output = F.ones(shape, dtype)
|
||||
expect = np.ones(shape, dtype_to_nptype(dtype))
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ones_dynamic_shape():
|
||||
"""
|
||||
Feature: test graph mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
||||
dyn_shape_tensor_run()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="CPU")
|
||||
dyn_shape_tensor_run()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ones_func_pynative_mode():
|
||||
"""
|
||||
Feature: test pynative mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="CPU")
|
||||
ones_func_run((2, 3), mstype.float32)
|
||||
ones_func_run((2,), mstype.float16)
|
||||
ones_func_run((2, 3, 4, 5), mstype.int32)
|
||||
ones_func_run((1, 64), mstype.int8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ones_func_graph_mode():
|
||||
"""
|
||||
Feature: test graph mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
||||
ones_func_run((2, 3), mstype.float32)
|
||||
ones_func_run((2,), mstype.float16)
|
||||
ones_func_run((2, 3, 4, 5), mstype.int32)
|
||||
ones_func_run((1, 64), mstype.int8)
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import function as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import dtype_to_nptype
|
||||
|
||||
|
||||
class ZerosNetDynTensor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ZerosNetDynTensor, self).__init__()
|
||||
self.unique = P.Unique()
|
||||
self.gather = P.Gather()
|
||||
self.x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int32))
|
||||
self.indices = Tensor(np.array([0, 1, 2, 6, 2, 1], dtype=np.int32))
|
||||
self.axis = 0
|
||||
|
||||
def construct(self, dtype):
|
||||
unique_indices, _ = self.unique(self.indices)
|
||||
input_x = self.gather(self.x, unique_indices, self.axis)
|
||||
return F.zeros(input_x, dtype)
|
||||
|
||||
|
||||
def dyn_shape_tensor_run():
|
||||
net = ZerosNetDynTensor()
|
||||
out = net(mstype.float32)
|
||||
expect = np.zeros((1, 2, 3, 7), dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect)
|
||||
|
||||
|
||||
def zeros_func_run(shape, dtype):
|
||||
output = F.zeros(shape, dtype)
|
||||
expect = np.zeros(shape, dtype_to_nptype(dtype))
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_zeros_dynamic_shape():
|
||||
"""
|
||||
Feature: test graph mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
||||
dyn_shape_tensor_run()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="CPU")
|
||||
dyn_shape_tensor_run()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_zeros_func_pynative_mode():
|
||||
"""
|
||||
Feature: test pynative mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="CPU")
|
||||
zeros_func_run((2, 3), mstype.float32)
|
||||
zeros_func_run((2,), mstype.float16)
|
||||
zeros_func_run((2, 3, 4, 5), mstype.int32)
|
||||
zeros_func_run((1, 64), mstype.int8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_zeros_func_graph_mode():
|
||||
"""
|
||||
Feature: test graph mode
|
||||
Description: compare result with numpy
|
||||
Expectation: calculate result same to numpy
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
||||
zeros_func_run((2, 3), mstype.float32)
|
||||
zeros_func_run((2,), mstype.float16)
|
||||
zeros_func_run((2, 3, 4, 5), mstype.int32)
|
||||
zeros_func_run((1, 64), mstype.int8)
|
|
@ -343,3 +343,16 @@ def vm_impl_load(self):
|
|||
return value
|
||||
|
||||
return vm_impl
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.FillV2)
|
||||
def vm_impl_fillv2(self):
|
||||
def vm_impl(x, y):
|
||||
if isinstance(x, Tensor):
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = np.empty(x).astype(y.dtype)
|
||||
out.fill(y)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
|
Loading…
Reference in New Issue