diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/CMakeLists.txt b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/CMakeLists.txt
index 70b67414de7..1b027bf3976 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/CMakeLists.txt
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/CMakeLists.txt
@@ -39,6 +39,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
         ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer_kernels.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/concat_offset_kernel.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/list_diff_kernel.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/random_shuffle_kernel.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/range_kernel.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/reservoir_replay_buffer.cc
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.cc
new file mode 100644
index 00000000000..aaec889ef1e
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.cc
@@ -0,0 +1,119 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.h"
+#include <unordered_set>
+#include "proto/aicpu_tensor.pb.h"
+
+namespace aicpu {
+namespace {
+constexpr size_t kDim0 = 0;
+constexpr size_t kDim1 = 1;
+constexpr size_t kDim2 = 2;
+constexpr size_t kDim3 = 3;
+constexpr size_t kListDiffAddressSize = 4;
+
+#define REG_LIST_DIFF_TYPE(data_type, type)          \
+  case (data_type): {                                \
+    if (idx_type_ == aicpuops::DataType::MS_INT64) { \
+      return ListDiffTask<type, int64_t>();          \
+    } else {                                         \
+      return ListDiffTask<type, int32_t>();          \
+    }                                                \
+    break;                                           \
+  }
+}  // namespace
+
+// Launch Kernel
+template <typename T, typename Tidx>
+uint32_t ListDiffKernel::ListDiffTask() {
+  if (io_addrs_.size() != kListDiffAddressSize) {
+    AICPU_LOGE("ListDiffKernel's address is invalid");
+    return kAicpuKernelStateFailed;
+  }
+
+  auto x_addr = reinterpret_cast<T *>(io_addrs_[kDim0]);
+  auto y_addr = reinterpret_cast<T *>(io_addrs_[kDim1]);
+  auto out_addr = reinterpret_cast<T *>(io_addrs_[kDim2]);
+  auto idx_addr = reinterpret_cast<Tidx *>(io_addrs_[kDim3]);
+
+  std::unordered_set<T> y_set;
+  y_set.reserve(y_size_);
+  for (int64_t i = 0; i < y_size_; ++i) {
+    (void)y_set.insert(y_addr[i]);
+  }
+
+  // calculate results
+  out_size_ = 0;
+  for (Tidx i = 0; i < static_cast<Tidx>(x_size_); ++i) {
+    if (y_set.count(x_addr[i]) == 0) {
+      out_addr[out_size_] = x_addr[i];
+      idx_addr[out_size_] = i;
+      ++out_size_;
+    }
+  }
+
+  // update out
+  if (output_shape_and_type_.size() < kDim2) {
+    AICPU_LOGE("ListDiffKernel's output size is invalid");
+    return kAicpuKernelStateFailed;
+  }
+
+  output_shape_and_type_[kDim0]->dims[kDim0] = out_size_;  // output
+  output_shape_and_type_[kDim1]->dims[kDim0] = out_size_;  // out_idx
+  return kAicpuKernelStateSucess;
+}
+
+// Init Kernel
+uint32_t ListDiffKernel::ParseKernelParam() {
+  idx_type_ = static_cast<aicpuops::DataType>(output_shape_and_type_[1]->type);
+
+  aicpuops::Tensor x = node_def_.inputs(kDim0);
+  input_type_ = static_cast<aicpuops::DataType>(x.tensor_type());
+  x_size_ = 0;
+  const auto &x_shape = x.tensor_shape();
+  for (int i = 0; i < x_shape.dim_size(); ++i) {
+    x_size_ += static_cast<int64_t>(x_shape.dim(i).size());
+  }
+
+  aicpuops::Tensor y = node_def_.inputs(kDim1);
+  const auto &y_shape = y.tensor_shape();
+  y_size_ = 0;
+  for (int i = 0; i < y_shape.dim_size(); ++i) {
+    y_size_ += static_cast<int64_t>(y_shape.dim(i).size());
+  }
+
+  return kAicpuKernelStateSucess;
+}
+
+// Get Support Type
+uint32_t ListDiffKernel::DoCompute() {
+  switch (input_type_) {
+    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_INT32, int32_t)
+    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_INT64, int64_t)
+    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_FLOAT32, float)
+    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_FLOAT64, double)
+    default:
+      return kAicpuKernelStateInvalid;
+  }
+}
+}  // namespace aicpu
+
+extern "C" {
+__attribute__((visibility("default"))) uint32_t ListDiff(void *param) {
+  aicpu::ListDiffKernel list_diff_kernel;
+  return list_diff_kernel.Compute(param);
+}
+}
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.h
new file mode 100644
index 00000000000..fd944ccf86d
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_
+#define AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_
+
+#include <vector>
+#include "common/kernel_base.h"
+
+namespace aicpu {
+class ListDiffKernel : public KernelBase {
+ public:
+  ListDiffKernel() : KernelBase("ListDiffKernel") {}
+  ~ListDiffKernel() = default;
+
+ protected:
+  uint32_t ParseKernelParam() override;
+  uint32_t DoCompute() override;
+
+  template <typename T, typename Tidx>
+  uint32_t ListDiffTask();
+  aicpuops::DataType input_type_{aicpuops::DataType::MS_UNKNOWN};
+  aicpuops::DataType idx_type_{aicpuops::DataType::MS_UNKNOWN};
+  int64_t x_size_;
+  int64_t y_size_;
+  int64_t out_size_;
+};
+}  // namespace aicpu
+#endif  // AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
index b3e4876a268..c3a72f701e7 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
@@ -177,6 +177,7 @@ constexpr auto kSign = "Sign";
 constexpr auto kArgmax = "Argmax";
 constexpr auto kArgmin = "Argmin";
 constexpr auto kRange = "Range";
+constexpr auto kListDiff = "ListDiff";
 
 const std::set<std::string> kCpuKernelOps{kIdentity,
                                           kMaskedSelect,
@@ -258,6 +259,7 @@ const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
                                               kGatherDGradV2,
                                               kConcatOffset,
                                               kRandomShuffle,
+                                              kListDiff,
                                               kRange};
 const std::set<std::string> kDynamicInputOps{kPrint,
                                              kPack,
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
index cad5755a7d6..59a29b79bac 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
@@ -45,6 +45,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
                                                             kGatherDGradV2OpName,
                                                             kConcatOffsetOpName,
                                                             kRandomShuffleOpName,
+                                                            kListDiffOpName,
                                                             kRangeOpName};
   static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 912d10cceb5..e58094b94c5 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -164,3 +164,4 @@ from .reservoir_replay_buffer import _rrb_destroy_op_cpu
 from .parallel_concat import _parallel_concat_aicpu
 from .concat_offset import _concat_offset_aicpu
 from .range import _range_aicpu
+from .list_diff import _list_diff_aicpu
diff --git a/tests/st/ops/ascend/test_list_diff.py b/tests/st/ops/ascend/test_list_diff.py
new file mode 100644
index 00000000000..c2f3643346a
--- /dev/null
+++ b/tests/st/ops/ascend/test_list_diff.py
@@ -0,0 +1,119 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import array_ops as op
+
+
+class ListDiffNet(nn.Cell):
+    def __init__(self, out_idx=mstype.int32):
+        super(ListDiffNet, self).__init__()
+        self.list_diff_op = op.ListDiff(out_idx=out_idx)
+
+    def construct(self, x, y):
+        return self.list_diff_op(x, y)
+
+
+def run_case(out_idx, is_dynamic):
+    np.random.seed(1024)
+    dtype = mstype.int32
+    x = Tensor(np.arange(1, 7, 1), dtype=mstype.int32)  # [1, 2, 3, 4, 5, 6]
+    y = Tensor([1, 3, 5], dtype=mstype.int32)
+
+    net = ListDiffNet(out_idx)
+    if is_dynamic:
+        dyn_shape = [None,]
+        x0_dyn = Tensor(shape=dyn_shape, dtype=dtype)
+        x1_dyn = Tensor(shape=dyn_shape, dtype=dtype)
+        net.set_inputs(x0_dyn, x1_dyn)
+
+    ms_out = net(x, y)
+
+    out_idx_np_type = np.int32 if out_idx == mstype.int32 else np.int64
+    expect = (np.array([2, 4, 6]).astype(np.int32), np.array([1, 3, 5]).astype(out_idx_np_type))
+    assert all(list(map(lambda x, y: np.allclose(x.asnumpy(), y), ms_out, expect)))
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_list_diff_int32():
+    """
+    Feature: test ListDiff op on Ascend.
+    Description: test ListDiff with int32 input and int32 out_idx.
+    Expectation: result is right.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    run_case(mstype.int32, False)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    run_case(mstype.int32, False)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_list_diff_int32_dyn():
+    """
+    Feature: test ListDiff op on Ascend.
+    Description: test ListDiff with int32 out_idx and dynamic-shape inputs.
+    Expectation: result is right.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    run_case(mstype.int32, True)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    run_case(mstype.int32, True)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_list_diff_int64():
+    """
+    Feature: test ListDiff op on Ascend.
+    Description: test ListDiff with int64 out_idx.
+    Expectation: result is right.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    run_case(mstype.int64, False)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    run_case(mstype.int64, False)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_list_diff_int64_dyn():
+    """
+    Feature: test ListDiff op on Ascend.
+    Description: test ListDiff with int64 out_idx and dynamic-shape inputs.
+    Expectation: result is right.
+ """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + run_case(mstype.int64, True) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + run_case(mstype.int64, True)