forked from mindspore-Ecosystem/mindspore
!43724 add list diff aicpu ops
Merge pull request !43724 from yangsijia/listdiff-aicpu
This commit is contained in: commit 692771c1a5
@@ -39,6 +39,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
     ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer_kernels.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/concat_offset_kernel.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/list_diff_kernel.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/random_shuffle_kernel.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/range_kernel.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/reservoir_replay_buffer.cc
@@ -0,0 +1,119 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/list_diff_kernel.h"
#include <unordered_set>
#include "proto/aicpu_tensor.pb.h"

namespace aicpu {
namespace {
constexpr size_t kDim0 = 0;
constexpr size_t kDim1 = 1;
constexpr size_t kDim2 = 2;
constexpr size_t kDim3 = 3;
constexpr size_t kListDiffAddressSize = 4;

#define REG_LIST_DIFF_TYPE(data_type, type)          \
  case (data_type): {                                \
    if (idx_type_ == aicpuops::DataType::MS_INT64) { \
      return ListDiffTask<type, int64_t>();          \
    } else {                                         \
      return ListDiffTask<type, int32_t>();          \
    }                                                \
    break;                                           \
  }
}  // namespace

// Launch Kernel
template <typename T, typename Tidx>
uint32_t ListDiffKernel::ListDiffTask() {
  if (io_addrs_.size() != kListDiffAddressSize) {
    AICPU_LOGE("ListDiffKernel's address is invalid");
    return kAicpuKernelStateFailed;
  }

  auto x_addr = reinterpret_cast<T *>(io_addrs_[kDim0]);
  auto y_addr = reinterpret_cast<T *>(io_addrs_[kDim1]);
  auto out_addr = reinterpret_cast<T *>(io_addrs_[kDim2]);
  auto idx_addr = reinterpret_cast<Tidx *>(io_addrs_[kDim3]);

  std::unordered_set<T> y_set;
  y_set.reserve(y_size_);
  for (int64_t i = 0; i < y_size_; ++i) {
    (void)y_set.insert(y_addr[i]);
  }

  // calculate results
  out_size_ = 0;
  for (Tidx i = 0; i < static_cast<Tidx>(x_size_); ++i) {
    if (y_set.count(x_addr[i]) == 0) {
      out_addr[out_size_] = x_addr[i];
      idx_addr[out_size_] = i;
      ++out_size_;
    }
  }

  // update out
  if (output_shape_and_type_.size() < kDim2) {
    AICPU_LOGE("ListDiffKernel's output size is invalid");
    return kAicpuKernelStateFailed;
  }

  output_shape_and_type_[kDim0]->dims[kDim0] = out_size_;  // output
  output_shape_and_type_[kDim1]->dims[kDim0] = out_size_;  // out_idx
  return kAicpuKernelStateSucess;
}

// Init Kernel
uint32_t ListDiffKernel::ParseKernelParam() {
  idx_type_ = static_cast<aicpuops::DataType>(output_shape_and_type_[1]->type);

  aicpuops::Tensor x = node_def_.inputs(kDim0);
  input_type_ = static_cast<aicpuops::DataType>(x.tensor_type());
  x_size_ = 0;
  const auto &x_shape = x.tensor_shape();
  for (int i = 0; i < x_shape.dim_size(); ++i) {
    x_size_ += static_cast<int64_t>(x_shape.dim(i).size());
  }

  aicpuops::Tensor y = node_def_.inputs(kDim1);
  const auto &y_shape = y.tensor_shape();
  y_size_ = 0;
  for (int i = 0; i < y_shape.dim_size(); ++i) {
    y_size_ += static_cast<int64_t>(y_shape.dim(i).size());
  }

  return kAicpuKernelStateSucess;
}

// Dispatch on the input data type
uint32_t ListDiffKernel::DoCompute() {
  switch (input_type_) {
    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_INT32, int32_t)
    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_INT64, int64_t)
    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_FLOAT32, float)
    REG_LIST_DIFF_TYPE(aicpuops::DataType::MS_FLOAT64, double)
    default:
      return kAicpuKernelStateInvalid;
  }
}
}  // namespace aicpu

extern "C" {
__attribute__((visibility("default"))) uint32_t ListDiff(void *param) {
  aicpu::ListDiffKernel list_diff_kernel;
  return list_diff_kernel.Compute(param);
}
}
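For reference, the kernel above computes TensorFlow-style ListDiff semantics over 1-D inputs: every element of x that never appears in y is kept, together with its index in x. A minimal NumPy sketch of the same hash-set algorithm (the function name and signature are illustrative only, not part of the PR):

import numpy as np

def list_diff_reference(x, y, idx_dtype=np.int32):
    """Reference for ListDiff: values of x absent from y, plus their indices in x."""
    y_set = set(y.tolist())      # mirrors the std::unordered_set built from y
    out, idx = [], []
    for i, v in enumerate(x.tolist()):
        if v not in y_set:       # keep elements of x that never occur in y
            out.append(v)
            idx.append(i)
    return np.array(out, dtype=x.dtype), np.array(idx, dtype=idx_dtype)

x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)
y = np.array([1, 3, 5], dtype=np.int32)
print(list_diff_reference(x, y))  # (array([2, 4, 6], dtype=int32), array([1, 3, 5], dtype=int32))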
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_
#define AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_

#include <vector>
#include "common/kernel_base.h"

namespace aicpu {
class ListDiffKernel : public KernelBase {
 public:
  ListDiffKernel() : KernelBase("ListDiffKernel") {}
  ~ListDiffKernel() = default;

 protected:
  uint32_t ParseKernelParam() override;
  uint32_t DoCompute() override;

  template <typename T, typename Tidx>
  uint32_t ListDiffTask();

  aicpuops::DataType input_type_{aicpuops::DataType::MS_UNKNOWN};
  aicpuops::DataType idx_type_{aicpuops::DataType::MS_UNKNOWN};
  int64_t x_size_;
  int64_t y_size_;
  int64_t out_size_;
};
}  // namespace aicpu
#endif  // AICPU_OPS_AICPU_LIST_DIFF_KERNELS_H_
@@ -177,6 +177,7 @@ constexpr auto kSign = "Sign";
 constexpr auto kArgmax = "Argmax";
 constexpr auto kArgmin = "Argmin";
 constexpr auto kRange = "Range";
+constexpr auto kListDiff = "ListDiff";

 const std::set<std::string> kCpuKernelOps{kIdentity,
                                           kMaskedSelect,
@@ -258,6 +259,7 @@ const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
                                               kGatherDGradV2,
                                               kConcatOffset,
                                               kRandomShuffle,
+                                              kListDiff,
                                               kRange};
 const std::set<std::string> kDynamicInputOps{kPrint,
                                              kPack,
@@ -45,6 +45,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
                                         kGatherDGradV2OpName,
                                         kConcatOffsetOpName,
                                         kRandomShuffleOpName,
+                                        kListDiffOpName,
                                         kRangeOpName};
 static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
@@ -164,3 +164,4 @@ from .reservoir_replay_buffer import _rrb_destroy_op_cpu
 from .parallel_concat import _parallel_concat_aicpu
 from .concat_offset import _concat_offset_aicpu
 from .range import _range_aicpu
+from .list_diff import _list_diff_aicpu
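The new import pulls the ListDiff op-info registration into the aicpu op package so graph compilation can route the op to this kernel. The diff does not show list_diff.py itself; as a rough, hypothetical sketch only, an AICPU op-info registration in MindSpore typically has this shape (attribute names and the dtype list below are illustrative, not copied from the PR):

from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# Hypothetical sketch; the actual list_diff.py is not shown in this diff.
list_diff_op_info = AiCPURegOp("ListDiff") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "y", "required") \
    .output(0, "out", "required") \
    .output(1, "idx", "required") \
    .attr("out_idx", "Type") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .get_op_info()

@op_info_register(list_diff_op_info)
def _list_diff_aicpu():
    """ListDiff AiCPU register (sketch)"""
    return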
@@ -0,0 +1,119 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import array_ops as op


class ListDiffNet(nn.Cell):
    def __init__(self, out_idx=mstype.int32):
        super(ListDiffNet, self).__init__()
        self.list_diff_op = op.ListDiff(out_idx=out_idx)

    def construct(self, x, y):
        return self.list_diff_op(x, y)


def run_case(out_idx, is_dynamic):
    np.random.seed(1024)
    dtype = mstype.int32
    x = Tensor(np.arange(1, 7, 1), dtype=mstype.int32)  # [1, 2, 3, 4, 5, 6]
    y = Tensor([1, 3, 5], dtype=mstype.int32)

    net = ListDiffNet(out_idx)
    if is_dynamic:
        dyn_shape = [None]
        x0_dyn = Tensor(shape=dyn_shape, dtype=dtype)
        x1_dyn = Tensor(shape=dyn_shape, dtype=dtype)
        net.set_inputs(x0_dyn, x1_dyn)

    ms_out = net(x, y)

    out_idx_np_type = np.int32 if out_idx == mstype.int32 else np.int64
    expect = (np.array([2, 4, 6]).astype(np.int32), np.array([1, 3, 5]).astype(out_idx_np_type))
    assert all(list(map(lambda x, y: np.allclose(x.asnumpy(), y), ms_out, expect)))


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_diff_int32():
    """
    Feature: test ListDiff op on Ascend.
    Description: test the ListDiff when input is int32.
    Expectation: result is right.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    run_case(mstype.int32, False)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    run_case(mstype.int32, False)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_diff_int32_dyn():
    """
    Feature: test ListDiff op on Ascend.
    Description: test the ListDiff when input is int32 dynamic shape.
    Expectation: result is right.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    run_case(mstype.int32, True)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    run_case(mstype.int32, True)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_diff_int64():
    """
    Feature: test ListDiff op on Ascend.
    Description: test the ListDiff when input is int64.
    Expectation: result is right.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    run_case(mstype.int64, False)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    run_case(mstype.int64, False)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_diff_int64_dyn():
    """
    Feature: test ListDiff op on Ascend.
    Description: test the ListDiff when input is int64 dynamic shape.
    Expectation: result is right.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    run_case(mstype.int64, True)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    run_case(mstype.int64, True)
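For quick manual verification outside pytest, a minimal standalone run of the op might look like the sketch below (assumes an Ascend environment with this PR's kernel built in; out_idx selects the index dtype that ListDiffTask dispatches on):

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import array_ops as op

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# out_idx=mstype.int64 exercises the ListDiffTask<T, int64_t> instantiation.
list_diff = op.ListDiff(out_idx=mstype.int64)
out, idx = list_diff(Tensor(np.array([1, 2, 3, 4, 5, 6], np.int32)),
                     Tensor(np.array([1, 3, 5], np.int32)))
print(out)  # expected: [2 4 6]
print(idx)  # expected: [1 3 5], dtype int64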