!4566 Add ScatterNdUpdate cpu kernel

Merge pull request !4566 from huanghui/scatter-nd-update-cpu-kernel
mindspore-ci-bot 2020-08-17 19:33:04 +08:00 committed by Gitee
commit 8c72d5b9d1
5 changed files with 294 additions and 1 deletion
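For context, ScatterNdUpdate takes an input tensor x, an int32 indices tensor whose last axis addresses positions in x, and an updates tensor, and writes each update unit into a copy of x at the addressed position. A minimal NumPy sketch of these semantics (the helper below is illustrative only, not part of this PR):

import numpy as np

def scatter_nd_update_ref(x, indices, updates):
    # Copy x, then write updates[i] at the position addressed by indices[i];
    # the last axis of indices selects leading dimensions of x.
    out = x.copy()
    for i, idx in enumerate(indices):
        out[tuple(idx)] = updates[i]
    return out

x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], np.float32)
indices = np.array([[0, 0], [1, 1]], np.int32)
updates = np.array([1.0, 2.2], np.float32)
print(scatter_nd_update_ref(x, indices, updates))
# [[ 1.   0.3  3.6]
#  [ 0.4  2.2 -3.2]]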

View File

@@ -38,7 +38,6 @@ bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, error no: " << ret;
-    return false;
   }
   return true;
 }

View File

@@ -37,16 +37,22 @@ MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReshapeCPUKernel);
 MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   ReshapeCPUKernel);
+MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                  ReshapeCPUKernel);
 MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReshapeCPUKernel);
 MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   ReshapeCPUKernel);
+MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                  ReshapeCPUKernel);
 MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReshapeCPUKernel);
 MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   ReshapeCPUKernel);
+MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                  ReshapeCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore

View File

@@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h"
#include <string>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  Check(kernel_node);
  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
  if (indices_shape.size() < 2) {
    MS_LOG(EXCEPTION) << "Indices rank is less than 2";
  }
  auto indices_unit_rank = indices_shape.back();
  if (indices_unit_rank > shape.size()) {
    MS_LOG(EXCEPTION) << "Value of last dimension of indices is greater than input shape rank";
  }
  if (updates_shape.size() != indices_shape.size() - 1 + shape.size() - indices_unit_rank) {
    MS_LOG(EXCEPTION) << "Updates rank is inconsistent with indices rank and input shape rank";
  }
  for (size_t i = 0; i < indices_shape.size() - 1; ++i) {
    if (updates_shape[i] != indices_shape[i]) {
      MS_LOG(EXCEPTION) << "Value of the " << i << "th dimension of indices is not equal to that of updates";
    }
  }
  indices_unit_rank_ = SizeToInt(indices_unit_rank);
  // Elements per update unit: the product of the trailing dims not addressed by an index.
  unit_size_ = 1;
  for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) {
    unit_size_ *= SizeToInt(updates_shape[i]);
  }
  // Number of update units: the product of the leading indices dims.
  num_units_ = 1;
  num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]);
  for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
    num_units_ *= SizeToInt(updates_shape[i]);
  }
  // Row-major strides of the leading indices_unit_rank_ dims of the input,
  // collected innermost-first and then reversed so that out_strides_[j]
  // matches the j-th index component.
  int out_stride = 1;
  out_strides_.push_back(out_stride);
  for (int i = indices_unit_rank_ - 2; i >= 0; i--) {
    out_stride *= SizeToInt(shape[i + 1]);
    out_strides_.push_back(out_stride);
  }
  std::reverse(out_strides_.begin(), out_strides_.end());
  shape_ = shape;
  output_unit_offsets_.resize(num_units_);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}

bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else {
    MS_LOG(ERROR) << "Only support float16, float32";
    return false;
  }
  return true;
}

template <typename T>
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &outputs) {
  auto x = reinterpret_cast<T *>(inputs[0]->addr);
  auto indices = reinterpret_cast<int *>(inputs[1]->addr);
  auto updates = reinterpret_cast<T *>(inputs[2]->addr);
  auto y = reinterpret_cast<T *>(outputs[0]->addr);
  // Translate each multi-dimensional index into a flat element offset in the output.
  for (int i = 0; i < num_units_; ++i) {
    int offset = 0;
    for (int j = 0; j < indices_unit_rank_; ++j) {
      offset += indices[i * indices_unit_rank_ + j] * out_strides_[j] * unit_size_;
    }
    output_unit_offsets_[i] = offset;
  }
  // Start from a copy of the input, then overwrite the addressed units.
  auto mem_bits = outputs[0]->size;
  auto ret = memcpy_s(y, mem_bits, x, mem_bits);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, error no: " << ret;
  }
  for (int i = 0; i < num_units_; i++) {
    ret =
      memcpy_s(y + output_unit_offsets_[i], unit_size_ * sizeof(T), updates + unit_size_ * i, unit_size_ * sizeof(T));
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, error no: " << ret;
    }
  }
}

void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 3) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ScatterNdUpdate needs 3 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ScatterNdUpdate needs 1 output.";
  }
}
} // namespace kernel
} // namespace mindspore
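
To make the stride and offset bookkeeping in InitKernel and LaunchKernel concrete, here is a small Python walkthrough for the shapes used in test_op1 below (x: (2, 3), indices: (2, 2), updates: (2,)). The variable names mirror the C++ members; the code is our illustration, not part of the PR:

import numpy as np

shape = (2, 3)
indices = np.array([[0, 0], [1, 1]], np.int32)
indices_unit_rank = indices.shape[-1]  # 2: each index addresses one scalar element
unit_size = 1                          # elements per update unit (no trailing dims)
num_units = indices.shape[0]           # 2 update units

# Row-major strides of the leading indices_unit_rank dims of x,
# collected innermost-first and then reversed, as in InitKernel.
out_strides = [1]
for i in range(indices_unit_rank - 2, -1, -1):
    out_strides.append(out_strides[-1] * shape[i + 1])
out_strides.reverse()  # [3, 1]

offsets = [sum(int(indices[i, j]) * out_strides[j] * unit_size
               for j in range(indices_unit_rank))
           for i in range(num_units)]
print(offsets)  # [0, 4]: flat positions of (0, 0) and (1, 1) in a 2x3 array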

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <unordered_map>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ScatterNdUpdateCPUKernel : public CPUKernel {
 public:
  ScatterNdUpdateCPUKernel() = default;
  ~ScatterNdUpdateCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  void Check(const CNodePtr &kernel_node);
  TypeId dtype_;
  int unit_size_;
  int num_units_;
  int indices_unit_rank_;
  std::vector<size_t> shape_;
  std::vector<int> output_unit_offsets_;
  std::vector<int> out_strides_;
};
MS_REG_CPU_KERNEL(ScatterNdUpdate,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  ScatterNdUpdateCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_

View File

@@ -0,0 +1,104 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=True)


class ScatterNdUpdate1(nn.Cell):
    def __init__(self):
        super(ScatterNdUpdate1, self).__init__()
        self.scatter_nd_update = P.ScatterNdUpdate()
        self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")

    def construct(self, indices, update):
        return self.scatter_nd_update(self.x, indices, update)


class ScatterNdUpdate2(nn.Cell):
    def __init__(self):
        super(ScatterNdUpdate2, self).__init__()
        self.scatter_nd_update = P.ScatterNdUpdate()
        self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")

    def construct(self, indices, update):
        return self.scatter_nd_update(self.x, indices, update)


class ScatterNdUpdate3(nn.Cell):
    def __init__(self):
        super(ScatterNdUpdate3, self).__init__()
        self.scatter_nd_update = P.ScatterNdUpdate()
        self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")

    def construct(self, indices, update):
        return self.scatter_nd_update(self.x, indices, update)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op1():
    indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
    update = Tensor(np.array([1.0, 2.2]), mstype.float32)
    scatter_nd_update = ScatterNdUpdate1()
    output = scatter_nd_update(indices, update)
    print("output:\n", output)
    expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
    assert np.allclose(output.asnumpy(), np.array(expect, np.float32))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op2():
    indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
    update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
    scatter_nd_update = ScatterNdUpdate2()
    output = scatter_nd_update(indices, update)
    print("output:\n", output)
    expect = [1, 11, 3, 10, 9, 6, 7, 12]
    assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op3():
    indices = Tensor(np.array([[0], [2]]), mstype.int32)
    update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
                               [7, 7, 7, 7], [8, 8, 8, 8]],
                              [[5, 5, 5, 5], [6, 6, 6, 6],
                               [7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
    scatter_nd_update = ScatterNdUpdate3()
    output = scatter_nd_update(indices, update)
    print("output:\n", output)
    expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
              [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
              [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
              [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
    assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))
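
As a NumPy-only cross-check of test_op3's expected output (our own sketch, not part of the test file): each index has length 1, so a whole (4, 4) unit is written, replacing rows 0 and 2 of the zero tensor.

import numpy as np

x = np.zeros((4, 4, 4), np.float32)
indices = np.array([[0], [2]], np.int32)
# Two (4, 4) update blocks: rows of 5s, 6s, 7s and 8s each.
update = np.tile(np.arange(5, 9, dtype=np.float32).reshape(4, 1), (2, 1, 4))

out = x.copy()
for i, idx in enumerate(indices):
    out[tuple(idx)] = update[i]  # idx has length 1, so a whole (4, 4) slice is written
print(np.array_equal(out[0], out[2]), out[1].sum())  # True 0.0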