forked from OSSInnovation/mindspore
!4566 Add ScatterNdUpdate cpu kernel
Merge pull request !4566 from huanghui/scatter-nd-update-cpu-kernel
This commit is contained in:
commit
8c72d5b9d1
|
@ -38,7 +38,6 @@ bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -37,16 +37,22 @@ MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
|
|||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReshapeCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReshapeCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReshapeCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
Check(kernel_node);
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
if (indices_shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Indices' dimension less than 2";
|
||||
}
|
||||
auto indices_unit_rank = indices_shape.back();
|
||||
if (indices_unit_rank > shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "Value of last dimension of indices is greater than shape rank";
|
||||
}
|
||||
if (indices_shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Indices dimension less than 2";
|
||||
}
|
||||
if (updates_shape.size() != indices_shape.size() - 1 + shape.size() - indices_unit_rank) {
|
||||
MS_LOG(EXCEPTION) << "Update, shape rank and indices rank inconsistent";
|
||||
}
|
||||
for (size_t i = 0; i < indices_shape.size() - 1; ++i) {
|
||||
if (updates_shape[i] != indices_shape[i]) {
|
||||
MS_LOG(EXCEPTION) << "Value of " << i << "th dimension of indices is not equal to that update";
|
||||
}
|
||||
}
|
||||
indices_unit_rank_ = SizeToInt(indices_unit_rank);
|
||||
unit_size_ = 1;
|
||||
for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) {
|
||||
unit_size_ *= SizeToInt(updates_shape[i]);
|
||||
}
|
||||
num_units_ = 1;
|
||||
num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]);
|
||||
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
|
||||
num_units_ *= SizeToInt(updates_shape[i]);
|
||||
}
|
||||
int out_stride = 1;
|
||||
out_strides_.push_back(out_stride);
|
||||
for (int i = indices_unit_rank_ - 2; i >= 0; i--) {
|
||||
out_stride *= shape[i + 1];
|
||||
out_strides_.push_back(out_stride);
|
||||
}
|
||||
shape_ = shape;
|
||||
output_unit_offsets_.reserve(num_units_);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support float16, float32";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto y = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
for (int i = 0; i < num_units_; ++i) {
|
||||
int offset = 0;
|
||||
for (int j = 0; j < indices_unit_rank_; ++j) {
|
||||
offset += indices[i * indices_unit_rank_ + j] * out_strides_[j] * unit_size_;
|
||||
}
|
||||
output_unit_offsets_[i] = offset;
|
||||
}
|
||||
|
||||
auto mem_bits = outputs[0]->size;
|
||||
auto ret = memcpy_s(y, mem_bits, x, mem_bits);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_units_; i++) {
|
||||
ret =
|
||||
memcpy_s(y + output_unit_offsets_[i], unit_size_ * sizeof(T), updates + unit_size_ * i, unit_size_ * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ScatterNdUpdate needs 3 input.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ScatterNdUpdate needs 1 output.";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ScatterNdUpdateCPUKernel : public CPUKernel {
|
||||
public:
|
||||
ScatterNdUpdateCPUKernel() = default;
|
||||
~ScatterNdUpdateCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
void Check(const CNodePtr &kernel_node);
|
||||
TypeId dtype_;
|
||||
int unit_size_;
|
||||
int num_units_;
|
||||
int indices_unit_rank_;
|
||||
std::vector<size_t> shape_;
|
||||
std::vector<int> output_unit_offsets_;
|
||||
std::vector<int> out_strides_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ScatterNdUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ND_UPDATE_CPU_KERNEL_H_
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=True)
|
||||
|
||||
|
||||
class ScatterNdUpdate1(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate1, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
class ScatterNdUpdate2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate2, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
class ScatterNdUpdate3(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate3, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op1():
|
||||
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
|
||||
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate1()
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("output:\n", output)
|
||||
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
|
||||
assert np.allclose(output.asnumpy(), np.array(expect, np.float))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op2():
|
||||
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
|
||||
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate2()
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("output:\n", output)
|
||||
expect = [1, 11, 3, 10, 9, 6, 7, 12]
|
||||
assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op3():
|
||||
indices = Tensor(np.array([[0], [2]]), mstype.int32)
|
||||
update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate3()
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("output:\n", output)
|
||||
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))
|
Loading…
Reference in New Issue