From 0d962a372fb74a0aff793b1190a3fa338f4e0b33 Mon Sep 17 00:00:00 2001 From: huanghui Date: Sat, 15 Aug 2020 11:24:09 +0800 Subject: [PATCH] add ScatterNdUpdate cpu kernel --- .../kernel_compiler/cpu/reshape_cpu_kernel.cc | 1 - .../kernel_compiler/cpu/reshape_cpu_kernel.h | 6 + .../cpu/scatter_nd_update_cpu_kernel.cc | 123 ++++++++++++++++++ .../cpu/scatter_nd_update_cpu_kernel.h | 61 +++++++++ tests/st/ops/cpu/test_scatter_nd_update_op.py | 104 +++++++++++++++ 5 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_scatter_nd_update_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc index 6370fdc78a..e0183307f5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc @@ -38,7 +38,6 @@ bool ReshapeCPUKernel::Launch(const std::vector &inputs, auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - return false; } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h index 915e1e8616..282c55133f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -37,16 +37,22 @@ MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp ReshapeCPUKernel); MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ReshapeCPUKernel); MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReshapeCPUKernel); MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ReshapeCPUKernel); MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReshapeCPUKernel); MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ReshapeCPUKernel); +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ReshapeCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc new file mode 100644 index 0000000000..b8584f8c4c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h" +#include +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { + Check(kernel_node); + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + if (indices_shape.size() < 2) { + MS_LOG(EXCEPTION) << "Indices' dimension less than 2"; + } + auto indices_unit_rank = indices_shape.back(); + if (indices_unit_rank > shape.size()) { + MS_LOG(EXCEPTION) << "Value of last dimension of indices is greater than shape rank"; + } + if (indices_shape.size() < 2) { + MS_LOG(EXCEPTION) << "Indices dimension less than 2"; + } + if (updates_shape.size() != indices_shape.size() - 1 + shape.size() - indices_unit_rank) { + MS_LOG(EXCEPTION) << "Update, shape rank and indices rank inconsistent"; + } + for (size_t i = 0; i < indices_shape.size() - 1; ++i) { + if (updates_shape[i] != indices_shape[i]) { + MS_LOG(EXCEPTION) << "Value of " << i << "th dimension of indices is not equal to that update"; + } + } + indices_unit_rank_ = SizeToInt(indices_unit_rank); + unit_size_ = 1; + for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) { + unit_size_ *= SizeToInt(updates_shape[i]); + } + num_units_ = 1; + num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]); + for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) { + num_units_ *= SizeToInt(updates_shape[i]); + } + int out_stride = 1; + out_strides_.push_back(out_stride); + for (int i = indices_unit_rank_ - 2; i >= 0; i--) { + out_stride *= shape[i + 1]; + out_strides_.push_back(out_stride); + } + shape_ = shape; + output_unit_offsets_.reserve(num_units_); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); +} + +bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(ERROR) << "Only support float16, float32"; + return false; + } + return true; +} + +template +void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto x = reinterpret_cast(inputs[0]->addr); + auto indices = reinterpret_cast(inputs[1]->addr); + auto updates = reinterpret_cast(inputs[2]->addr); + auto y = reinterpret_cast(outputs[0]->addr); + + for (int i = 0; i < num_units_; ++i) { + int offset = 0; + for (int j = 0; j < indices_unit_rank_; ++j) { + offset += indices[i * indices_unit_rank_ + j] * out_strides_[j] * unit_size_; + } + output_unit_offsets_[i] = offset; + } + + auto mem_bits = outputs[0]->size; + auto ret = memcpy_s(y, mem_bits, x, mem_bits); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + + for (int i = 0; i < num_units_; i++) { + ret = + memcpy_s(y + output_unit_offsets_[i], unit_size_ * sizeof(T), updates + unit_size_ * i, unit_size_ * sizeof(T)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + } +} + +void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ScatterNdUpdate needs 3 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ScatterNdUpdate needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h new file mode 100644 index 0000000000..b5ee09295a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ScatterNdUpdateCPUKernel : public CPUKernel { + public: + ScatterNdUpdateCPUKernel() = default; + ~ScatterNdUpdateCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void Check(const CNodePtr &kernel_node); + TypeId dtype_; + int unit_size_; + int num_units_; + int indices_unit_rank_; + std::vector shape_; + std::vector output_unit_offsets_; + std::vector out_strides_; +}; + +MS_REG_CPU_KERNEL(ScatterNdUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ScatterNdUpdateCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_scatter_nd_update_op.py b/tests/st/ops/cpu/test_scatter_nd_update_op.py new file mode 100644 index 0000000000..e563c55faa --- /dev/null +++ b/tests/st/ops/cpu/test_scatter_nd_update_op.py @@ -0,0 +1,104 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=True) + + +class ScatterNdUpdate1(nn.Cell): + def __init__(self): + super(ScatterNdUpdate1, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + + +class ScatterNdUpdate2(nn.Cell): + def __init__(self): + super(ScatterNdUpdate2, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + + +class ScatterNdUpdate3(nn.Cell): + def __init__(self): + super(ScatterNdUpdate3, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_op1(): + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + update = Tensor(np.array([1.0, 2.2]), mstype.float32) + + scatter_nd_update = ScatterNdUpdate1() + output = scatter_nd_update(indices, update) + print("output:\n", output) + expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]] + assert np.allclose(output.asnumpy(), np.array(expect, np.float)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_op2(): + indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32) + update = Tensor(np.array([9, 10, 11, 12]), mstype.float32) + + scatter_nd_update = ScatterNdUpdate2() + output = scatter_nd_update(indices, update) + print("output:\n", output) + expect = [1, 11, 3, 10, 9, 6, 7, 12] + assert np.allclose(output.asnumpy(), np.array(expect, dtype=float)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_op3(): + indices = Tensor(np.array([[0], [2]]), mstype.int32) + update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32) + + scatter_nd_update = ScatterNdUpdate3() + output = scatter_nd_update(indices, update) + print("output:\n", output) + expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))