!35440 Add testcases for ScatterUpdate and unify the testing codes of ScatterMin/Max/Update

Merge pull request !35440 from zhengzuohe/test_scatterupdate
This commit is contained in:
i-robot 2022-06-09 03:20:59 +00:00 committed by Gitee
commit 4f07986eae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 479 additions and 530 deletions

View File

@ -81,7 +81,7 @@ TypePtr ScatterArithmeticInferType(const PrimitivePtr &primitive, const std::vec
AbstractBasePtr ScatterArithmeticInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
constexpr int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = ScatterArithmeticInferType(primitive, input_args);
auto infer_shape = ScatterArithmeticInferShape(primitive, input_args);

View File

@ -29,6 +29,8 @@ class MIND_API ScatterUpdate : public BaseOperator {
MIND_API_BASE_MEMBER(ScatterUpdate);
/// \brief Constructor. Registers input names {input_x, indices, updates} and output name {output}.
ScatterUpdate() : BaseOperator(kNameScatterUpdate) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
/// \brief Init. No attributes to initialize for ScatterUpdate, so this is a no-op.
void Init() const {}
};
} // namespace ops
} // namespace mindspore

View File

@ -4267,6 +4267,11 @@ class ScatterUpdate(Primitive):
[[2. 1.2 1.]
[3. 1.2 1.]]
"""
# Signature: x is read-write (RW_WRITE, updated in place); indices may use a
# separate dtype group (T1) while updates must share x's dtype group (T).
__mindspore_signature__ = (
    sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
    sig.make_sig('indices', dtype=sig.sig_dtype.T1),
    sig.make_sig('updates', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):

View File

@ -0,0 +1,437 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops.functional import vmap
from mindspore import Tensor, Parameter, ParameterTuple
# all cases tested against dchip
# Maps a short name to the scatter Primitive class under test; the Cell
# wrappers below look up and instantiate the operator by this name.
func_map = {
    "max": ops.ScatterMax,
    "min": ops.ScatterMin,
    "update": ops.ScatterUpdate,
}
class TestScatterFuncNet(nn.Cell):
    """Wraps one scatter op (max/min/update) with its input held as a Parameter."""

    def __init__(self, func, inputx):
        super(TestScatterFuncNet, self).__init__()
        # Resolve the operator class by short name and instantiate it.
        self.scatter_func = func_map.get(func)()
        self.inputx = Parameter(inputx, name="inputx")

    def construct(self, indices, updates):
        # Apply the scatter op to the parameter and return its result.
        return self.scatter_func(self.inputx, indices, updates)
def scatter_func_forward(nptype):
    """Check forward results of ScatterMax/ScatterMin/ScatterUpdate.

    Args:
        nptype: numpy dtype for inputx/updates (indices are always int32).

    ScatterUpdate is only exercised for float16/float32.
    """
    inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype))
    indices = Tensor(
        np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype))
    # scatter_max
    net = TestScatterFuncNet("max", inputx)
    output = net(indices, updates)
    expected = inputx.asnumpy()  # NOTE(review): immediately overwritten below; looks redundant
    expected = np.array(
        [[55.0, 56.0, 57.0], [64.0, 65.0, 66.0], [67.0, 68.0, 69.0]]).astype(nptype)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_min: expected is the current inputx snapshot (min leaves values unchanged here)
    net = TestScatterFuncNet("min", inputx)
    output = net(indices, updates)
    expected = inputx.asnumpy()
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_update
    if nptype not in (np.float16, np.float32):
        return
    net = TestScatterFuncNet("update", inputx)
    output = net(indices, updates)
    expected = inputx.asnumpy()  # NOTE(review): immediately overwritten below; looks redundant
    expected = np.array(
        [[55.0, 56.0, 57.0], [64.0, 65.0, 66.0], [67.0, 68.0, 69.0]]).astype(nptype)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_func_dynamic_updates():
    """Run max/min/update with the `updates` input declared dynamic.

    updates_dy leaves axis 3 unknown (None), so set_inputs compiles the net
    for a dynamic updates shape before the concrete tensor is fed in.
    """
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
    updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32)
    # scatter_max
    net = TestScatterFuncNet("max", inputx)
    net.set_inputs(indices, updates_dy)
    output = net(indices, updates)
    expected = np.array([[[[1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_min: only the cell hit by update value 0 can go below the initial ones
    net = TestScatterFuncNet("min", inputx)
    net.set_inputs(indices, updates_dy)
    output = net(indices, updates)
    expected = np.ones((4, 2, 3, 4)).astype(np.float32)
    expected[0][0][0][0] = 0.0
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_update
    net = TestScatterFuncNet("update", inputx)
    net.set_inputs(indices, updates_dy)
    output = net(indices, updates)
    expected = np.array([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_func_dynamic_indices():
    """Run max/min/update with the `indices` input declared dynamic.

    indices_dy leaves axis 1 unknown (None); max/min use int32 data, while the
    update branch rebuilds its inputs with float32 data.
    """
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32))
    # scatter_max
    net = TestScatterFuncNet("max", inputx)
    net.set_inputs(indices_dy, updates)
    output = net(indices, updates)
    expected = np.array([[[[1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.int32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_min: only the cell hit by update value 0 can go below the initial ones
    net = TestScatterFuncNet("min", inputx)
    net.set_inputs(indices_dy, updates)
    output = net(indices, updates)
    expected = np.ones((4, 2, 3, 4)).astype(np.int32)
    expected[0][0][0][0] = 0
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
    # scatter_update: inputs rebuilt with float32 data for this branch
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
    net = TestScatterFuncNet("update", inputx)
    net.set_inputs(indices_dy, updates)
    output = net(indices, updates)
    expected = np.array([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class TestScatterFuncGradNet(nn.Cell):
    """Computes gradients of `network` w.r.t. all of its inputs and parameters."""

    def __init__(self, network):
        super(TestScatterFuncGradNet, self).__init__()
        # get_all: grads w.r.t. inputs; get_by_list: grads w.r.t. params;
        # sens_param: caller supplies the output sensitivity (dout).
        self.grad = ops.GradOperation(
            get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, indices, updates, dout):
        grad_fn = self.grad(self.network, self.params)
        return grad_fn(indices, updates, dout)
def scatter_func_grad(nptype):
    """Check gradients of max/min/update w.r.t. indices and updates.

    Args:
        nptype: numpy dtype for inputx/updates/dout (indices are int32).

    Gradients w.r.t. indices are expected to be all zeros; gradients
    w.r.t. updates are dout values routed back through the indices.
    ScatterUpdate is only exercised for float16/float32.
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype)))
    indices = Tensor(
        np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype))
    dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype)))
    # Shared expectations for all three ops.
    indices_expected = np.array(
        [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype)
    updates_expected = np.array(
        [
            [
                [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]],
                [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]
            ],
            [
                [[11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8]],
                [[3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0]]
            ]
        ]).astype(nptype)
    # scatter_max
    net = TestScatterFuncGradNet(TestScatterFuncNet("max", inputx))
    output = net(indices, updates, dout)
    indices_grad = output[0][0]
    updates_grad = output[0][1]
    np.testing.assert_array_almost_equal(indices_grad, indices_expected)
    np.testing.assert_array_almost_equal(updates_grad, updates_expected)
    # scatter_min
    net = TestScatterFuncGradNet(TestScatterFuncNet("min", inputx))
    output = net(indices, updates, dout)
    indices_grad = output[0][0]
    updates_grad = output[0][1]
    np.testing.assert_array_almost_equal(indices_grad, indices_expected)
    np.testing.assert_array_almost_equal(updates_grad, updates_expected)
    # scatter_update
    if nptype not in (np.float16, np.float32):
        return
    net = TestScatterFuncGradNet(TestScatterFuncNet("update", inputx))
    output = net(indices, updates, dout)
    indices_grad = output[0][0]
    updates_grad = output[0][1]
    np.testing.assert_array_almost_equal(indices_grad, indices_expected)
    np.testing.assert_array_almost_equal(updates_grad, updates_expected)
class ScatterFuncVmapNet(nn.Cell):
    """Thin Cell wrapper exposing a scatter op (by name) for vmap tests."""

    def __init__(self, func):
        super(ScatterFuncVmapNet, self).__init__()
        self.scatter_func = func_map.get(func)()

    def construct(self, inputx, indices, updates):
        result = self.scatter_func(inputx, indices, updates)
        return result
class VmapNet(nn.Cell):
    """Applies vmap over `net` with the given in/out axes; inputx is a Parameter."""

    def __init__(self, net, inputx, in_axes, out_axes):
        super(VmapNet, self).__init__()
        self.net = net
        self.in_axes = in_axes
        self.out_axes = out_axes
        self.inputx = Parameter(inputx, name="inputx")

    def construct(self, indices, updates):
        # Build the vectorized callable, then invoke it on (inputx, indices, updates).
        batched = vmap(self.net, self.in_axes, self.out_axes)
        return batched(self.inputx, indices, updates)
def scatter_func_indices_vmap():
    """vmap over the leading axis of inputx and indices (in_axes=(0, 0, None)).

    Only ScatterUpdate is covered in this case.
    """
    inputx = Parameter(Tensor(np.array(
        [[[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]]]
    ).astype(np.float32)), name="inputx")
    indices = Tensor(np.array(
        [[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]]).astype(np.int32))
    updates = Tensor(
        np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]).astype(np.float32))
    # scatter_update
    output = VmapNet(ScatterFuncVmapNet("update"), inputx,
                     (0, 0, None), 0)(indices, updates)
    expected = np.array(
        [[[1, 1, 1], [4, 4, 4]], [[3, 3, 3], [4, 4, 4]], [[4, 4, 4], [3, 3, 3]]]
    ).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_func_updates_vmap():
    """vmap over the leading axis of inputx and updates (in_axes=(0, None, 0)).

    Only ScatterUpdate is covered in this case.
    """
    inputx = Parameter(Tensor(np.array(
        [[0.1, 1.0, 2.2], [3.0, 4.3, 5.5]]).astype(np.float32)), name="inputx")
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    updates = Tensor(np.array([[1.0, 0.1], [1.2, 1.3]]).astype(np.float32))
    # scatter_update
    output = VmapNet(ScatterFuncVmapNet("update"), inputx,
                     (0, None, 0), 0)(indices, updates)
    expected = np.array([[1.0, 0.1, 2.2], [1.2, 1.3, 5.5]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_forward_float16():
    """
    Feature: test scatter_func forward.
    Description: test float16 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_forward(np.float16)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_forward_float32():
    """
    Feature: test scatter_func forward.
    Description: test float32 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_forward(np.float32)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_forward_int32():
    """
    Feature: test scatter_func forward.
    Description: test int32 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_forward(np.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_dynamic_indices():
    """
    Feature: test scatter_func dynamic shape.
    Description: indices is dynamic shape.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_dynamic_indices()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_dynamic_updates():
    """
    Feature: test scatter_func dynamic shape.
    Description: updates is dynamic shape.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_dynamic_updates()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_grad_float16():
    """
    Feature: test scatter_func grad.
    Description: test float16 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_grad(np.float16)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_grad_float32():
    """
    Feature: test scatter_func grad.
    Description: test float32 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_grad(np.float32)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_grad_int32():
    """
    Feature: test scatter_func grad.
    Description: test int32 inputs.
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_grad(np.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_indices_vmap():
    """
    Feature: test scatter_func vmap.
    Description: in_axes: (0, 0, None).
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_indices_vmap()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_func_updates_vmap():
    """
    Feature: test scatter_func vmap.
    Description: in_axes: (0, None, 0).
    Expectation: the result match with numpy result
    """
    # Run the same case first in graph mode, then in pynative mode.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="Ascend")
        scatter_func_updates_vmap()

View File

@ -1,271 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
# all cases tested against dchip
class TestScatterMaxNet(nn.Cell):
    # Wraps ops.ScatterMax; inputx is held as a Parameter named "inputx".
    def __init__(self, inputx):
        super(TestScatterMaxNet, self).__init__()
        self.scatter_max = ops.ScatterMax()
        self.inputx = Parameter(inputx, name="inputx")
    def construct(self, indices, updates):
        out = self.scatter_max(self.inputx, indices, updates)
        return out
def scatter_max_forward(nptype):
    """Check ScatterMax forward output against a precomputed expectation.

    Args:
        nptype: numpy dtype for inputx/updates (indices are int32).
    """
    inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype))
    indices = Tensor(
        np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype))
    net = TestScatterMaxNet(inputx)
    output = net(indices, updates)
    expected = inputx.asnumpy()  # NOTE(review): immediately overwritten below; looks redundant
    expected = np.array(
        [[55.0, 56.0, 57.0], [64.0, 65.0, 66.0], [67.0, 68.0, 69.0]]).astype(nptype)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_max_dynamic_updates():
    """Run ScatterMax with `updates` declared dynamic (axis 3 unknown) via set_inputs."""
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
    updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32)
    net = TestScatterMaxNet(inputx)
    net.set_inputs(indices, updates_dy)
    output = net(indices, updates)
    expected = np.array([[[[1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_max_dynamic_indices():
    """Run ScatterMax with `indices` declared dynamic (axis 1 unknown) via set_inputs."""
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32))
    net = TestScatterMaxNet(inputx)
    net.set_inputs(indices_dy, updates)
    output = net(indices, updates)
    expected = np.array([[[[1, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                          [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]],
                         [[[72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83]],
                          [[84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95]]],
                         [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]],
                          [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]],
                         [[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]],
                          [[60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71]]]]).astype(np.int32)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class TestScatterMaxGradNet(nn.Cell):
    # Computes gradients of `network` w.r.t. all inputs (get_all) and its
    # parameters (get_by_list); the caller supplies dout (sens_param).
    def __init__(self, network):
        super(TestScatterMaxGradNet, self).__init__()
        self.grad = ops.GradOperation(
            get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())
    def construct(self, indices, updates, dout):
        out = self.grad(self.network, self.params)(indices, updates, dout)
        return out
def scatter_max_grad(nptype):
    """Check ScatterMax gradients w.r.t. indices (all zeros) and updates.

    Args:
        nptype: numpy dtype for inputx/updates/dout (indices are int32).
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype)))
    indices = Tensor(
        np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype))
    dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype)))
    net = TestScatterMaxGradNet(TestScatterMaxNet(inputx))
    output = net(indices, updates, dout)
    indices_grad = output[0][0]
    updates_grad = output[0][1]
    indices_expected = np.array(
        [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype)
    # Expected updates grad: dout values routed back through the indices.
    updates_expected = np.array(
        [
            [
                [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]],
                [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]
            ],
            [
                [[11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8]],
                [[3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0]]
            ]
        ]).astype(nptype)
    np.testing.assert_array_almost_equal(indices_grad, indices_expected)
    np.testing.assert_array_almost_equal(updates_grad, updates_expected)
# ---- pytest entry points: each case runs in GRAPH mode, then PYNATIVE mode ----
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_forward_float16():
    """
    Feature: test scatter_max forward.
    Description: test float16 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_forward(np.float16)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_forward(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_forward_float32():
    """
    Feature: test scatter_max forward.
    Description: test float32 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_forward(np.float32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_forward(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_forward_int32():
    """
    Feature: test scatter_max forward.
    Description: test int32 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_forward(np.int32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_forward(np.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_dynamic_indices():
    """
    Feature: test scatter_max dynamic shape.
    Description: indices is dynamic shape.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_dynamic_indices()
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_dynamic_indices()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_dynamic_updates():
    """
    Feature: test scatter_max dynamic shape.
    Description: updates is dynamic shape.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_dynamic_updates()
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_dynamic_updates()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_grad_float16():
    """
    Feature: test scatter_max grad.
    Description: test float16 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_grad(np.float16)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_grad(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_grad_float32():
    """
    Feature: test scatter_max grad.
    Description: test float32 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_grad(np.float32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_grad(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_max_grad_int32():
    """
    Feature: test scatter_max grad.
    Description: test int32 inputs.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    scatter_max_grad(np.int32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    scatter_max_grad(np.int32)
# Allow running this file directly without pytest.
if __name__ == "__main__":
    test_scatter_max_forward_float16()
    test_scatter_max_forward_float32()
    test_scatter_max_forward_int32()
    test_scatter_max_dynamic_indices()
    test_scatter_max_dynamic_updates()
    test_scatter_max_grad_float16()
    test_scatter_max_grad_float32()
    test_scatter_max_grad_int32()

View File

@ -1,250 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
# all cases tested against dchip
class TestScatterMinNet(nn.Cell):
    # Wraps ops.ScatterMin; inputx is held as a Parameter named "inputx".
    def __init__(self, inputx):
        super(TestScatterMinNet, self).__init__()
        self.scatter_min = ops.ScatterMin()
        self.inputx = Parameter(inputx, name="inputx")
    def construct(self, indices, updates):
        out = self.scatter_min(self.inputx, indices, updates)
        return out
def scatter_min_forward(nptype):
    """Check ScatterMin forward output; values are expected unchanged here.

    Args:
        nptype: numpy dtype for inputx/updates (indices are int32).
    """
    inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype))
    indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype))
    net = TestScatterMinNet(inputx)
    output = net(indices, updates)
    # Expected is the inputx snapshot: all updates (34..69) exceed inputx (0..8).
    expected = inputx.asnumpy()
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_min_dynamic_updates():
    """Run ScatterMin with `updates` declared dynamic (axis 3 unknown) via set_inputs."""
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
    updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32)
    net = TestScatterMinNet(inputx)
    net.set_inputs(indices, updates_dy)
    output = net(indices, updates)
    # Only the cell hit by update value 0 drops below the initial ones.
    expected = np.ones((4, 2, 3, 4)).astype(np.float32)
    expected[0][0][0][0] = 0.0
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_min_dynamic_indices():
    """Run ScatterMin with `indices` declared dynamic (axis 1 unknown) via set_inputs."""
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32))
    net = TestScatterMinNet(inputx)
    net.set_inputs(indices_dy, updates)
    output = net(indices, updates)
    # Only the cell hit by update value 0 drops below the initial ones.
    expected = np.ones((4, 2, 3, 4)).astype(np.int32)
    expected[0][0][0][0] = 0
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class TestScatterMinGradNet(nn.Cell):
    # Computes gradients of `network` w.r.t. all inputs (get_all) and its
    # parameters (get_by_list); the caller supplies dout (sens_param).
    def __init__(self, network):
        super(TestScatterMinGradNet, self).__init__()
        self.grad = ops.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())
    def construct(self, indices, updates, dout):
        out = self.grad(self.network, self.params)(indices, updates, dout)
        return out
def scatter_min_grad(nptype):
    """Check ScatterMin gradients w.r.t. indices (all zeros) and updates.

    Args:
        nptype: numpy dtype for inputx/updates/dout (indices are int32).
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype))
    dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype)))
    net = TestScatterMinGradNet(TestScatterMinNet(inputx))
    output = net(indices, updates, dout)
    indices_grad = output[0][0]
    updates_grad = output[0][1]
    indices_expected = np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype)
    # Expected updates grad: dout values routed back through the indices.
    updates_expected = np.array(
        [
            [
                [
                    [11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]
                ],
                [
                    [3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]
                ]
            ],
            [
                [
                    [11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8]
                ],
                [
                    [3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0]
                ]
            ]
        ]).astype(nptype)
    np.testing.assert_array_almost_equal(indices_grad, indices_expected)
    np.testing.assert_array_almost_equal(updates_grad, updates_expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_float16():
"""
Feature: test scatter_min forward.
Description: test float16 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_float32():
"""
Feature: test scatter_min forward.
Description: test float32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_int32():
"""
Feature: test scatter_min forward.
Description: test int32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.int32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_dynamic_indices():
    """
    Feature: test scatter_min dynamic shape.
    Description: indices is dynamic shape.
    Expectation: the result match with numpy result
    """
    # Exercise the dynamic-indices path in both execution modes.
    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="Ascend")
        scatter_min_dynamic_indices()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_dynamic_updates():
    """
    Feature: test scatter_min dynamic shape.
    Description: updates is dynamic shape.
    Expectation: the result match with numpy result
    """
    # Exercise the dynamic-updates path in both execution modes.
    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="Ascend")
        scatter_min_dynamic_updates()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_float16():
    """
    Feature: test scatter_min grad.
    Description: test float16 inputs.
    Expectation: the result match with numpy result
    """
    # Verify the backward pass in graph mode first, then pynative mode.
    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="Ascend")
        scatter_min_grad(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_float32():
    """
    Feature: test scatter_min grad.
    Description: test float32 inputs.
    Expectation: the result match with numpy result
    """
    # Verify the backward pass in graph mode first, then pynative mode.
    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="Ascend")
        scatter_min_grad(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_int32():
    """
    Feature: test scatter_min grad.
    Description: test int32 inputs.
    Expectation: the result match with numpy result
    """
    # Verify the backward pass in graph mode first, then pynative mode.
    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="Ascend")
        scatter_min_grad(np.int32)

View File

@ -38,7 +38,7 @@ class TestScatterFuncNet(nn.Cell):
def __init__(self, func, lock, inputx, indices, updates):
super(TestScatterFuncNet, self).__init__()
self.scatter_func = func_map[func](use_locking=lock)
self.scatter_func = func_map.get(func)(use_locking=lock)
self.inputx = Parameter(inputx, name="inputx")
self.indices = Parameter(indices, name="indices")
self.updates = Parameter(updates, name="updates")
@ -63,7 +63,7 @@ def scatter_func_use_locking_false_net(func, inputx, indices, updates):
class TestScatterFuncDynamicNet(nn.Cell):
def __init__(self, func, inputx, indices, updates):
super(TestScatterFuncDynamicNet, self).__init__()
self.scatter_func = func_map[func]()
self.scatter_func = func_map.get(func)()
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.inputx = Parameter(inputx, name="inputx")
self.indices = Parameter(indices, name="indices")
@ -85,7 +85,7 @@ def scatter_func_d_net(func, inputx, indices, updates):
class TestScatterFuncDynamicNet2(nn.Cell):
def __init__(self, func, inputx):
super(TestScatterFuncDynamicNet2, self).__init__()
self.scatter_func = func_map[func]()
self.scatter_func = func_map.get(func)()
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.inputx = Parameter(inputx, name="inputx")
@ -423,7 +423,12 @@ def test_scatter_func_input_less_than_1_float32():
# update
output = scatter_func_net("update", inputx, indices, updates)
expected = np.array(
[[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0],], dtype=np.float32,
[
[37.0, 38.0, 39.0],
[34.0, 35.0, 66.0],
[67.0, 68.0, 69.0],
],
dtype=np.float32,
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@ -950,21 +955,35 @@ def test_scatter_func_indices_vmap():
).astype(np.int32)), name="inputx")
indices = Tensor(np.array([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]]).astype(np.int32))
updates = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]).astype(np.int32))
in_axes = (0, 0, None)
out_axes = 0
# scatter_max
output = VmapNet(ScatterFuncVmapNet("max"), inputx, (0, 0, None), 0)(indices, updates)
output = VmapNet(ScatterFuncVmapNet("max"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array(
[[[1, 1, 2], [4, 4, 5]], [[3, 3, 3], [4, 4, 5]], [[4, 4, 4], [3, 4, 5]]]
).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# scatter_min
output = VmapNet(ScatterFuncVmapNet("min"), inputx, (0, 0, None), 0)(indices, updates)
output = VmapNet(ScatterFuncVmapNet("min"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array(
[[[0, 1, 1], [2, 2, 2]], [[0, 1, 1], [2, 2, 2]], [[0, 1, 2], [1, 1, 1]]]
).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# scatter_update
inputx = Parameter(Tensor(np.array(
[[[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]]]
).astype(np.float32)), name="inputx")
indices = Tensor(np.array([[0, 1], [1, 0], [0, 1]]).astype(np.int32))
updates = Tensor(np.array([[1, 1, 1], [2, 2, 2]]).astype(np.float32))
output = VmapNet(ScatterFuncVmapNet("update"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array(
[[[1, 1, 1], [2, 2, 2]], [[2, 2, 2], [1, 1, 1]], [[1, 1, 1], [2, 2, 2]]]
).astype(np.float32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -978,13 +997,20 @@ def test_scatter_func_updates_vmap():
inputx = Parameter(Tensor(np.array([[0.1, 1.0, 2.2], [3.0, 4.3, 5.5]]).astype(np.float32)), name="inputx")
indices = Tensor(np.array([0, 1]).astype(np.int32))
updates = Tensor(np.array([[1.0, 0.1], [1.2, 1.3]]).astype(np.float32))
in_axes = (0, None, 0)
out_axes = 0
# scatter_max
output = VmapNet(ScatterFuncVmapNet("max"), inputx, (0, None, 0), 0)(indices, updates)
output = VmapNet(ScatterFuncVmapNet("max"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array([[1.0, 1.0, 2.2], [3.0, 4.3, 5.5]]).astype(np.float32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# scatter_min
output = VmapNet(ScatterFuncVmapNet("min"), inputx, (0, None, 0), 0)(indices, updates)
output = VmapNet(ScatterFuncVmapNet("min"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array([[0.1, 0.1, 2.2], [1.2, 1.3, 5.5]]).astype(np.float32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# scatter_update
output = VmapNet(ScatterFuncVmapNet("update"), inputx, in_axes, out_axes)(indices, updates)
expected = np.array([[1.0, 0.1, 2.2], [1.2, 1.3, 5.5]]).astype(np.float32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)