diff --git a/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.cc b/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.cc
new file mode 100644
index 00000000000..e4b3f03f58e
--- /dev/null
+++ b/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.cc
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "device/cpu/kernel/one_hot_cpu_kernel.h"
+#include "device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+// Derive axis_/depth_/stride_ from the node's inferred output shape and its
+// 'axis' attribute (-1 selects the last output dimension).
+void OneHotCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  if (output_shape.size() < 2) {
+    MS_LOG(EXCEPTION) << "invalid output shape size: " << output_shape.size();
+  }
+  int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
+  if (axis != -1 && IntToSize(axis) >= output_shape.size()) {
+    MS_LOG(EXCEPTION) << "invalid axis: " << axis;
+  }
+  if (axis == -1) {
+    axis_ = output_shape.size() - 1;
+  } else {
+    axis_ = IntToSize(axis);
+  }
+  depth_ = output_shape[axis_];
+  // stride_ is the number of elements after the one-hot axis; each input index
+  // maps to a run of depth_ output slots separated by stride_.
+  stride_ = 1;
+  for (size_t i = axis_ + 1; i < output_shape.size(); ++i) {
+    stride_ *= output_shape[i];
+  }
+}
+
+// Expand int32 indices (inputs[0]) into a float one-hot tensor (outputs[0]):
+// on_value (inputs[1]) at the matching position along the one-hot axis,
+// off_value (inputs[2]) everywhere else.
+bool OneHotCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                             const std::vector<kernel::AddressPtr> & /*workspace*/,
+                             const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() < 3 || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "input or output invalid!";
+  }
+  auto indices = reinterpret_cast<int *>(inputs[0]->addr);
+  auto on_value = reinterpret_cast<float *>(inputs[1]->addr)[0];
+  auto off_value = reinterpret_cast<float *>(inputs[2]->addr)[0];
+  auto output = reinterpret_cast<float *>(outputs[0]->addr);
+  size_t elem_num = inputs[0]->size / sizeof(int);
+
+  for (size_t i = 0; i < elem_num; i++) {
+    size_t stride_num = i / stride_;
+    // First output slot for this index along the one-hot axis.
+    size_t output_index = stride_num * depth_ * stride_ + i % stride_;
+    size_t index = IntToSize(indices[i]);
+    // An out-of-range index leaves the whole run at off_value.
+    for (size_t j = 0; j < depth_; j++) {
+      if (index == j) {
+        output[output_index] = on_value;
+      } else {
+        output[output_index] = off_value;
+      }
+      output_index += stride_;
+    }
+  }
+
+  return true;
+}
+}  // namespace cpu
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.h b/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.h
new file mode 100644
index 00000000000..f41ac63265a
--- /dev/null
+++ b/mindspore/ccsrc/device/cpu/kernel/one_hot_cpu_kernel.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_DEVICE_CPU_ONE_HOT_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_DEVICE_CPU_ONE_HOT_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "device/cpu/cpu_kernel.h"
+#include "device/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+class OneHotCPUKernel : public CPUKernel {  // expands int32 indices into a float one-hot tensor
+ public:
+  OneHotCPUKernel() = default;
+  ~OneHotCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;  // reads output shape and 'axis' attribute
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  size_t depth_;   // extent of the one-hot axis
+  size_t stride_;  // product of dimensions after the one-hot axis
+  size_t axis_;    // resolved (non-negative) one-hot axis
+};
+
+MS_REG_CPU_KERNEL(OneHot, OneHotCPUKernel);
+}  // namespace cpu
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_CPU_ONE_HOT_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/device/cpu/kernel/reshape_cpu_kernel.h b/mindspore/ccsrc/device/cpu/kernel/reshape_cpu_kernel.h
index 908c3df2d9f..d371e3a7ac2 100644
--- a/mindspore/ccsrc/device/cpu/kernel/reshape_cpu_kernel.h
+++ b/mindspore/ccsrc/device/cpu/kernel/reshape_cpu_kernel.h
@@ -35,6 +35,8 @@ class ReshapeCPUKernel : public CPUKernel {
 };
 
 MS_REG_CPU_KERNEL(Reshape, ReshapeCPUKernel);
+MS_REG_CPU_KERNEL(Flatten, ReshapeCPUKernel);
+MS_REG_CPU_KERNEL(ExpandDims, ReshapeCPUKernel);
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore
diff --git a/tests/st/ops/cpu/test_one_hot_op.py b/tests/st/ops/cpu/test_one_hot_op.py
new file mode 100644
index 00000000000..3f2c54b3cb6
--- /dev/null
+++ b/tests/st/ops/cpu/test_one_hot_op.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+from mindspore.common.api import ms_function
+import numpy as np
+import mindspore.context as context
+
+context.set_context(device_target='CPU')
+
+
+class NetOneHot(nn.Cell):  # exercises OneHot over several axes/depths
+    def __init__(self):
+        super(NetOneHot, self).__init__()
+        self.on_value = 2.0
+        self.off_value = 3.0
+
+        self.depth_1 = 6
+        self.one_hot_1 = nn.OneHot(-1, self.depth_1, self.on_value, self.off_value)  # axis=-1, depth=6
+
+        self.depth_2 = 4
+        self.one_hot_2 = nn.OneHot(0, self.depth_1, self.on_value, self.off_value)  # axis=0, depth=6
+        self.one_hot_3 = nn.OneHot(0, self.depth_2, self.on_value, self.off_value)  # axis=0, depth=4
+        self.one_hot_4 = nn.OneHot(1, self.depth_1, self.on_value, self.off_value)  # axis=1, depth=6
+
+    @ms_function
+    def construct(self, indices1, indices2, indices3, indices4):
+        return (self.one_hot_1(indices1), self.one_hot_2(indices2),
+                self.one_hot_3(indices3), self.one_hot_4(indices4))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_one_hot():  # out-of-range indices (5 with depth 4, 6 with depth 6) must yield all off_value
+    one_hot = NetOneHot()
+    indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
+    indices2 = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
+    indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
+    output = one_hot(indices1, indices2, indices3, indices4)
+    expect_0 = np.array([
+        [[2., 3., 3., 3., 3., 3.], [3., 2., 3., 3., 3., 3.]],
+        [[3., 3., 3., 3., 2., 3.], [3., 3., 3., 3., 3., 2.]],
+        [[3., 3., 2., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]]
+    ]).astype(np.float32)
+    expect_1 = np.array([
+        [3., 3., 3.],
+        [2., 3., 3.],
+        [3., 2., 3.],
+        [3., 3., 2.],
+        [3., 3., 3.],
+        [3., 3., 3.]
+    ]).astype(np.float32)
+    expect_2 = np.array([
+        [[2., 3.], [3., 2.]], [[3., 2.], [2., 3.]], [[3., 3.], [3., 3.]],
+        [[3., 3.], [3., 3.]]
+    ]).astype(np.float32)
+    expect_3 = np.array([
+        [[2., 3.], [3., 2.], [3., 3.], [3., 3.], [3., 3.], [3., 3.]],
+        [[3., 3.], [3., 3.], [3., 3.], [3., 3.], [2., 3.], [3., 2.]],
+        [[3., 3.], [3., 3.], [2., 3.], [3., 3.], [3., 3.], [3., 3.]]
+    ]).astype(np.float32)
+    assert (output[0].asnumpy() == expect_0).all()
+    assert (output[1].asnumpy() == expect_1).all()
+    assert (output[2].asnumpy() == expect_2).all()
+    assert (output[3].asnumpy() == expect_3).all()