forked from mindspore-Ecosystem/mindspore
migrates aicpu kernels to MS from lqk
This commit is contained in:
parent
2f3d008c2b
commit
389acff921
|
@ -86,4 +86,4 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "variableScope"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constParameter"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "unreadVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "unreadVariable"
|
||||
|
|
|
@ -130,4 +130,5 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "readability/namespace"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/braces"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "build/include"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/end_of_line"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/end_of_line"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "readability/casting"
|
||||
|
|
|
@ -387,6 +387,7 @@ constexpr auto kLSTMOpName = "LSTM";
|
|||
constexpr auto kLuUnpackOpName = "LuUnpack";
|
||||
constexpr auto kMaskedFillOpName = "MaskedFill";
|
||||
constexpr auto kMaskedSelectOpName = "MaskedSelect";
|
||||
constexpr auto kMaskedSelectGradOpName = "MaskedSelectGrad";
|
||||
constexpr auto kMatMulOpName = "MatMul";
|
||||
constexpr auto kMatMulV2OpName = "MatMulV2";
|
||||
constexpr auto kMatrixDiagOpName = "MatrixDiag";
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cpu_kernel/ms_kernel/acos.h"
|
||||
|
||||
#include <unsupported/Eigen/CXX11/Tensor>
|
||||
#include <algorithm>
|
||||
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const std::uint32_t kAcosInputNum{1u};
|
||||
const std::uint32_t kAcosOutputNum{1u};
|
||||
const char *const kAcos{"Acos"};
|
||||
const std::int64_t kAcosParallelNum{64 * 1024};
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
inline T ScalarAcos(const T x) {
|
||||
return std::acos(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Eigen::half ScalarAcos(const Eigen::half x) {
|
||||
const Eigen::half val{static_cast<Eigen::half>(std::acos(static_cast<std::float_t>(x)))};
|
||||
return val;
|
||||
}
|
||||
|
||||
inline std::uint32_t ParallelForAcos(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size,
|
||||
const std::function<void(std::int64_t, std::int64_t)> &work) {
|
||||
if (total > kAcosParallelNum)
|
||||
return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work);
|
||||
else
|
||||
work(0, total);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeAcosKernel(const CpuKernelContext &ctx) {
|
||||
T *input0{static_cast<T *>(ctx.Input(0)->GetData())};
|
||||
T *output{static_cast<T *>(ctx.Output(0)->GetData())};
|
||||
std::int64_t total{ctx.Input(0)->NumElements()};
|
||||
std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)};
|
||||
std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)};
|
||||
return ParallelForAcos(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) {
|
||||
std::transform(input0 + begin, input0 + end, output + begin, ScalarAcos<T>);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeAcos(const CpuKernelContext &ctx) {
|
||||
std::uint32_t result{ComputeAcosKernel<T>(ctx)};
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Acos compute failed.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline std::uint32_t ExtraCheckAcos(const CpuKernelContext &ctx) {
|
||||
if (ctx.Input(0)->GetData() == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get input data failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ctx.Output(0)->GetData() == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get output data failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ctx.Input(0)->GetDataType() != ctx.Output(0)->GetDataType()) {
|
||||
KERNEL_LOG_ERROR("The data type of the input [%s] need be the same as the output [%s].",
|
||||
DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize()) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The data size of the input [%llu] need be the same as the output "
|
||||
"[%llu].",
|
||||
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
inline std::uint32_t CheckAcos(const CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) {
|
||||
return NormalCheck(ctx, inputs_num, outputs_num) ? KERNEL_STATUS_PARAM_INVALID : ExtraCheckAcos(ctx);
|
||||
}
|
||||
|
||||
inline std::uint32_t ComputeAcos(const CpuKernelContext &ctx) {
|
||||
DataType input_type{ctx.Input(0)->GetDataType()};
|
||||
switch (input_type) {
|
||||
case DT_FLOAT16:
|
||||
return ComputeAcos<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return ComputeAcos<std::float_t>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return ComputeAcos<std::double_t>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::uint32_t AcosCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
return detail::CheckAcos(ctx, kAcosInputNum, kAcosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : detail::ComputeAcos(ctx);
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kAcos, AcosCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cache_swap_table.h"
|
||||
#include <securec.h>
|
||||
#include <map>
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/sparse_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const char *const kCacheSwapTable = "CacheSwapTable";
|
||||
}
|
||||
|
||||
namespace aicpu {
|
||||
template <typename T>
|
||||
uint32_t CacheSwapTableTask(std::vector<Tensor *> &inputs, std::vector<Tensor *> &outputs, int64_t batch_size,
|
||||
int64_t output_size, int64_t one_line_col, int type_size) {
|
||||
if (inputs.size() == 0 || outputs.size() == 0) {
|
||||
KERNEL_LOG_ERROR("CacheSwapTable input or output is empty.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
char *cache_table = reinterpret_cast<char *>(inputs[0]->GetData());
|
||||
T *swap_cache_idx = reinterpret_cast<T *>(inputs[1]->GetData());
|
||||
uint64_t swap_cache_idx_size = inputs[1]->GetDataSize();
|
||||
char *miss_value = reinterpret_cast<char *>(inputs[2]->GetData());
|
||||
|
||||
char *old_value = reinterpret_cast<char *>(outputs[0]->GetData());
|
||||
|
||||
errno_t ret = memset_s(old_value, static_cast<size_t>(output_size * type_size), 0x00,
|
||||
static_cast<size_t>(output_size * type_size));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("Memset failed, result[%d]", ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
uint64_t single_copy_size = static_cast<uint64_t>(type_size * one_line_col);
|
||||
|
||||
if (swap_cache_idx_size < static_cast<uint64_t>(batch_size)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The value of swap_cache_idx_size:[%llu] must be less than "
|
||||
"batch_size:[%lld]",
|
||||
swap_cache_idx_size, batch_size);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
uint64_t old_value_size = outputs[0]->GetDataSize();
|
||||
uint64_t cache_table_size = inputs[0]->GetDataSize();
|
||||
for (int64_t i = 0; i < batch_size; ++i) {
|
||||
if (swap_cache_idx[i] < 0) {
|
||||
continue;
|
||||
}
|
||||
ret = memcpy_s(old_value + i * single_copy_size, old_value_size, cache_table + swap_cache_idx[i] * single_copy_size,
|
||||
single_copy_size);
|
||||
old_value_size -= single_copy_size;
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("CacheSwapTable memcpy failed, result [%d].", ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
ret = memcpy_s(cache_table + swap_cache_idx[i] * single_copy_size, cache_table_size,
|
||||
miss_value + i * single_copy_size, single_copy_size);
|
||||
cache_table_size -= single_copy_size;
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("CacheSwapTable memcpy failed, result [%d].", ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CacheSwapTableMsCpuKernel::DoCompute() {
|
||||
std::map<int, std::function<uint32_t(std::vector<Tensor *> &, std::vector<Tensor *> &, int64_t &, int64_t &,
|
||||
int64_t &, int &)>>
|
||||
calls;
|
||||
calls[DT_INT32] = CacheSwapTableTask<int32_t>;
|
||||
calls[DT_INT64] = CacheSwapTableTask<int64_t>;
|
||||
|
||||
if (calls.find(indices_type_) == calls.end()) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"CacheSwapTableMsCpuKernel op doesn't support indices tensor types: "
|
||||
"[%s]",
|
||||
DTypeStr(indices_type_).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
int type_size = GetSizeByDataType(param_type_);
|
||||
return calls[indices_type_](inputs_, outputs_, batch_size_, output_size_, one_line_col_, type_size);
|
||||
}
|
||||
|
||||
uint32_t CacheSwapTableMsCpuKernel::GetInputAndCheck(const CpuKernelContext &ctx) {
|
||||
KERNEL_LOG_INFO("GetInputAndCheck start!");
|
||||
// get input Tensors
|
||||
const uint32_t kNumInput = 3;
|
||||
for (uint32_t i = 0; i < kNumInput; ++i) {
|
||||
Tensor *tensor = ctx.Input(i);
|
||||
KERNEL_CHECK_NULLPTR(tensor, KERNEL_STATUS_PARAM_INVALID, "Get input tensor[%d] failed", i)
|
||||
inputs_.push_back(tensor);
|
||||
}
|
||||
// get output Tensors
|
||||
const uint32_t kNumOutput = 1;
|
||||
for (uint32_t i = 0; i < kNumOutput; ++i) {
|
||||
Tensor *tensor = ctx.Output(i);
|
||||
KERNEL_CHECK_NULLPTR(tensor, KERNEL_STATUS_PARAM_INVALID, "Get output tensor[%d] failed", i)
|
||||
outputs_.push_back(tensor);
|
||||
}
|
||||
// get param type
|
||||
param_type_ = static_cast<DataType>(inputs_[0]->GetDataType());
|
||||
indices_type_ = static_cast<DataType>(inputs_[1]->GetDataType());
|
||||
KERNEL_LOG_INFO("GetInputAndCheck success!");
|
||||
|
||||
std::shared_ptr<TensorShape> cache_table_shape = ctx.Input(0)->GetTensorShape();
|
||||
std::shared_ptr<TensorShape> indices_shape = ctx.Input(1)->GetTensorShape();
|
||||
|
||||
for (int32_t i = 1; i < cache_table_shape->GetDims(); ++i) {
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(one_line_col_, cache_table_shape->GetDimSize(i), one_line_col_,
|
||||
KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
for (int32_t i = 0; i < indices_shape->GetDims(); ++i) {
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(batch_size_, indices_shape->GetDimSize(i), batch_size_, KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
output_size_ = batch_size_ * one_line_col_;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CacheSwapTableMsCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
uint32_t res = GetInputAndCheck(ctx);
|
||||
if (res != KERNEL_STATUS_OK) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = DoCompute();
|
||||
if (res != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Compute failed");
|
||||
return res;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kCacheSwapTable, CacheSwapTableMsCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CACHE_SWAP_TABLE_H
|
||||
#define AICPU_KERNELS_NORMALIZED_CACHE_SWAP_TABLE_H
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CacheSwapTableMsCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~CacheSwapTableMsCpuKernel() = default;
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t DoCompute();
|
||||
|
||||
uint32_t GetInputAndCheck(const CpuKernelContext &ctx);
|
||||
|
||||
int64_t batch_size_ = 1;
|
||||
int64_t one_line_col_ = 1;
|
||||
int64_t output_size_ = 1;
|
||||
|
||||
std::vector<Tensor *> inputs_;
|
||||
std::vector<Tensor *> outputs_;
|
||||
DataType param_type_ = DT_FLOAT;
|
||||
DataType indices_type_ = DT_INT32;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "fill.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const char *const kFill = "Fill";
|
||||
}
|
||||
|
||||
namespace aicpu {
|
||||
template <typename T>
|
||||
void FillGenerateCase(Tensor *&value_tensor, Tensor *&output) {
|
||||
auto value = *(reinterpret_cast<T *>(value_tensor->GetData()));
|
||||
if (AddrAlignedCheck(output->GetData())) {
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Aligned> eigen_output(static_cast<T *>(output->GetData()),
|
||||
output->GetTensorShape()->NumElements());
|
||||
eigen_output.setConstant(value);
|
||||
} else {
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> eigen_output(static_cast<T *>(output->GetData()),
|
||||
output->GetTensorShape()->NumElements());
|
||||
eigen_output.setConstant(value);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t FillCpuKernel::GetDimsByType(const CpuKernelContext &ctx) {
|
||||
dims.clear();
|
||||
Tensor *dims_tensor = ctx.Input(0);
|
||||
KERNEL_CHECK_NULLPTR(dims_tensor, KERNEL_STATUS_PARAM_INVALID, "Get dims input failed")
|
||||
uint32_t ret;
|
||||
auto dims_dtype = dims_tensor->GetDataType();
|
||||
switch (dims_dtype) {
|
||||
case (DT_INT32):
|
||||
ret = CalcDims<int32_t>(dims_tensor, dims);
|
||||
break;
|
||||
case (DT_INT64):
|
||||
ret = CalcDims<int64_t>(dims_tensor, dims);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR(
|
||||
"Fill kernel dims data_type [%u] not support, support data_types: "
|
||||
"DT_INT32, DT_INT64",
|
||||
dims_dtype);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Fill kernel calculate dims failed");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint32_t FillCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
uint32_t check = GetDimsByType(ctx);
|
||||
if (check != KERNEL_STATUS_OK) {
|
||||
return check;
|
||||
}
|
||||
Tensor *value_tensor = ctx.Input(1);
|
||||
KERNEL_CHECK_NULLPTR(value_tensor, KERNEL_STATUS_PARAM_INVALID, "Get value input failed")
|
||||
KERNEL_CHECK_NULLPTR(value_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get value input data failed")
|
||||
KERNEL_CHECK_NULLPTR(value_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get value input shape failed")
|
||||
if (!value_tensor->GetTensorShape()->GetDimSizes().empty()) {
|
||||
KERNEL_LOG_ERROR("Fill kernel value input is not a scalar.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
Tensor *output = ctx.Output(0);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output failed")
|
||||
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
|
||||
KERNEL_CHECK_NULLPTR(output->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get output shape failed")
|
||||
if (output->GetTensorShape()->GetDimSizes() != dims) {
|
||||
KERNEL_LOG_ERROR("Fill kernel output shape not matched.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto input_dtype = value_tensor->GetDataType();
|
||||
auto output_dtype = output->GetDataType();
|
||||
if (input_dtype != output_dtype) {
|
||||
KERNEL_LOG_ERROR("Fill kernel data type not matched, value input dtype [%u], output dtype [%u].", input_dtype,
|
||||
output_dtype);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::map<int, std::function<void(Tensor *&, Tensor *&)>> calls;
|
||||
calls[DT_INT8] = FillGenerateCase<int8_t>;
|
||||
calls[DT_UINT8] = FillGenerateCase<uint8_t>;
|
||||
calls[DT_INT16] = FillGenerateCase<int16_t>;
|
||||
calls[DT_UINT16] = FillGenerateCase<uint16_t>;
|
||||
calls[DT_INT32] = FillGenerateCase<int32_t>;
|
||||
calls[DT_UINT32] = FillGenerateCase<uint32_t>;
|
||||
calls[DT_INT64] = FillGenerateCase<int64_t>;
|
||||
calls[DT_UINT64] = FillGenerateCase<uint64_t>;
|
||||
calls[DT_BOOL] = FillGenerateCase<bool>;
|
||||
calls[DT_FLOAT16] = FillGenerateCase<Eigen::half>;
|
||||
calls[DT_FLOAT] = FillGenerateCase<float>;
|
||||
calls[DT_DOUBLE] = FillGenerateCase<double>;
|
||||
|
||||
if (calls.find(output_dtype) == calls.end()) {
|
||||
KERNEL_LOG_ERROR("Fill kernel data type [%u] not support", output_dtype);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
calls[output_dtype](value_tensor, output);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t FillCpuKernel::CalcDims(const Tensor *dims_tensor, std::vector<int64_t> &dim_vec) {
|
||||
uint64_t data_num = dims_tensor->GetDataSize() / sizeof(T);
|
||||
if (data_num == 0) {
|
||||
KERNEL_LOG_INFO("Fill kernel: dims is empty, fill scalar output.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
KERNEL_CHECK_NULLPTR(dims_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get dims data failed")
|
||||
for (uint64_t i = 0; i < data_num; i++) {
|
||||
auto dim = *(reinterpret_cast<const T *>(dims_tensor->GetData()) + i);
|
||||
if (dim < 0) {
|
||||
KERNEL_LOG_ERROR("Fill kernel: input dim [%llu] is negative, value=[%lld]", i, static_cast<int64_t>(dim));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
// zero dim is different from empty dim.
|
||||
if (dim == 0) {
|
||||
KERNEL_LOG_INFO("Fill kernel: input dim [%llu] is zero", i);
|
||||
}
|
||||
dim_vec.emplace_back(dim);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kFill, FillCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_FILL_H
|
||||
#define AICPU_KERNELS_NORMALIZED_FILL_H
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class FillCpuKernel : public CpuKernel {
|
||||
public:
|
||||
FillCpuKernel() = default;
|
||||
~FillCpuKernel() override = default;
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t GetDimsByType(const CpuKernelContext &ctx);
|
||||
/**
|
||||
* @brief calc dims from input dims tensor
|
||||
* @param dims_tensor input dims tensor
|
||||
* @param dims output shape dims
|
||||
* @return status if success
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t CalcDims(const Tensor *dims_tensor, std::vector<int64_t> &dims);
|
||||
|
||||
std::vector<int64_t> dims;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_KERNELS_NORMALIZED_FILL_H_
|
|
@ -0,0 +1,293 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "masked_select.h"
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "Eigen/Core"
|
||||
#include "securec.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/broadcast_iterator.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
constexpr uint32_t kMaskedSelectInputNum = 2;
|
||||
constexpr uint32_t kMaskedSelectOutputNum = 1;
|
||||
constexpr int64_t kParallelDataNums = 32 * 1000;
|
||||
const char *const kMaskedSelect = "MaskedSelect";
|
||||
struct OutputInfo {
|
||||
int64_t startIdx;
|
||||
int64_t len;
|
||||
OutputInfo() {
|
||||
startIdx = 0;
|
||||
len = 0;
|
||||
}
|
||||
};
|
||||
|
||||
bool CompareFunc(const OutputInfo &a, const OutputInfo &b) { return a.startIdx <= b.startIdx; }
|
||||
|
||||
// calculate the index stride of dataShape.
|
||||
// dataShape:[m, 1, k] and broadcastShape:[j, m, n, k] --> index_stride:[0, k, 0, 1]
|
||||
std::vector<int64_t> CalIndexStride(const std::vector<int64_t> &dataShape, const std::vector<int64_t> &broadcastShape) {
|
||||
int broadcastDimNum = broadcastShape.size();
|
||||
int dataDimNum = dataShape.size();
|
||||
int diffDimNum = broadcastDimNum - dataDimNum;
|
||||
std::vector<int64_t> indexStride(broadcastDimNum, 0);
|
||||
indexStride[broadcastDimNum - 1] = 1;
|
||||
for (int i = broadcastDimNum - 1; i > diffDimNum; i--) {
|
||||
indexStride[i - 1] = indexStride[i] * dataShape[i];
|
||||
}
|
||||
for (int i = 0; i < dataDimNum; i++) {
|
||||
if (dataShape[i] == 1) {
|
||||
indexStride[i + diffDimNum] = 0;
|
||||
}
|
||||
}
|
||||
return indexStride;
|
||||
}
|
||||
|
||||
// calculate the index stride of shape.
|
||||
// shape:[m, n, k] --> index_stride:[n*k, k, 1]
|
||||
std::vector<int64_t> CalIndexStride(const std::vector<int64_t> &shape) {
|
||||
int dimNum = shape.size();
|
||||
std::vector<int64_t> indexStride(dimNum, 1);
|
||||
for (int i = dimNum - 1; i > 0; i--) {
|
||||
indexStride[i - 1] = indexStride[i] * shape[i];
|
||||
}
|
||||
return indexStride;
|
||||
}
|
||||
|
||||
// calculate the original index of data.
|
||||
// shape:[7,8,9] indexStride:[72,9,1] and flatten_index:11--> ori_index:[0,1,2]
|
||||
bool CalIndexInfo(const std::vector<int64_t> &indexStride, int64_t flattenIndex, std::vector<int64_t> &oriIndex,
|
||||
int dimNum) {
|
||||
for (int i = 0; i < dimNum - 1; i++) {
|
||||
if (indexStride[i] == 0) {
|
||||
return false;
|
||||
}
|
||||
oriIndex[i] = flattenIndex / indexStride[i];
|
||||
flattenIndex = flattenIndex % indexStride[i];
|
||||
}
|
||||
oriIndex[dimNum - 1] = flattenIndex;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline int64_t CalFlattenIndex(const std::vector<int64_t> &indexStride, const std::vector<int64_t> &oriIndex,
|
||||
int dimNum) {
|
||||
int64_t flattenIndex = 0;
|
||||
for (int i = 0; i < dimNum; i++) {
|
||||
flattenIndex += indexStride[i] * oriIndex[i];
|
||||
}
|
||||
return flattenIndex;
|
||||
}
|
||||
|
||||
void UpdateIndexByCarry(std::vector<int64_t> &preIndex, const std::vector<int64_t> &shape, int dimNum) {
|
||||
// shape:[7,3,10,17] and last index:[0,0,9,16] -> next index:[0,1,0,0]
|
||||
constexpr int64_t carryBit = 1;
|
||||
for (int i = dimNum - 1; i >= 0; i--) {
|
||||
preIndex[i] = preIndex[i] + carryBit;
|
||||
if (preIndex[i] < shape[i]) {
|
||||
break;
|
||||
} else {
|
||||
preIndex[i] = preIndex[i] - shape[i];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t MaskedSelectCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kMaskedSelectInputNum, kMaskedSelectOutputNum), "[%s] check params failed.",
|
||||
kMaskedSelect);
|
||||
|
||||
// choose compute function depend on dataType
|
||||
auto data_type0 = static_cast<DataType>(ctx.Input(kFirstInputIndex)->GetDataType());
|
||||
auto data_type1 = static_cast<DataType>(ctx.Input(kSecondInputIndex)->GetDataType());
|
||||
auto data_type2 = static_cast<DataType>(ctx.Output(kFirstOutputIndex)->GetDataType());
|
||||
if (data_type1 != DT_BOOL) {
|
||||
KERNEL_LOG_ERROR("[%s] Data type of mask requires bool, but got data type [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(data_type1).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
if (data_type0 != data_type2) {
|
||||
KERNEL_LOG_ERROR("[%s] Data type of x and y requires same, but got data type [%s] and [%s].",
|
||||
ctx.GetOpType().c_str(), DTypeStr(data_type0).c_str(), DTypeStr(data_type2).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
switch (data_type0) {
|
||||
case DT_FLOAT16:
|
||||
return MaskedSelectCompute<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return MaskedSelectCompute<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return MaskedSelectCompute<double>(ctx);
|
||||
case DT_INT8:
|
||||
return MaskedSelectCompute<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return MaskedSelectCompute<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return MaskedSelectCompute<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return MaskedSelectCompute<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return MaskedSelectCompute<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return MaskedSelectCompute<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return MaskedSelectCompute<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return MaskedSelectCompute<uint64_t>(ctx);
|
||||
case DT_BOOL:
|
||||
return MaskedSelectCompute<bool>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not support, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(data_type0).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t MaskedSelectCpuKernel::ParallelCompute(const CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
|
||||
const std::vector<int64_t> &inputShapeMask,
|
||||
const std::vector<int64_t> &outputShape, int64_t dataNum) {
|
||||
T *x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
bool *mask = reinterpret_cast<bool *>(ctx.Input(1)->GetData());
|
||||
T *y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
||||
std::atomic<int> threadNum{0};
|
||||
std::atomic<bool> taskFlag(true);
|
||||
constexpr int queueLen = 100;
|
||||
std::array<OutputInfo, queueLen> outputIndexList;
|
||||
|
||||
std::vector<int64_t> indexStrideX = CalIndexStride(inputShapeX, outputShape);
|
||||
std::vector<int64_t> indexStrideMask = CalIndexStride(inputShapeMask, outputShape);
|
||||
std::vector<int64_t> indexStrideOutput = CalIndexStride(outputShape);
|
||||
KERNEL_LOG_DEBUG("index stride of x[%s].", VectorToString(indexStrideX).c_str());
|
||||
KERNEL_LOG_DEBUG("index stride of mask[%s].", VectorToString(indexStrideMask).c_str());
|
||||
|
||||
auto work = [=, &threadNum, &taskFlag, &outputIndexList](int64_t start, int64_t end) {
|
||||
int64_t cnt = 0;
|
||||
int dimNum = outputShape.size();
|
||||
std::vector<int64_t> indexValue(dimNum, 0);
|
||||
if (!CalIndexInfo(indexStrideOutput, start, indexValue, dimNum)) {
|
||||
taskFlag.store(false);
|
||||
KERNEL_LOG_ERROR("Invalid index stride, please check.");
|
||||
return;
|
||||
}
|
||||
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
int64_t maskFlatIndex = CalFlattenIndex(indexStrideMask, indexValue, dimNum);
|
||||
int64_t xFlatIndex = CalFlattenIndex(indexStrideX, indexValue, dimNum);
|
||||
if (mask[maskFlatIndex]) {
|
||||
y[start + cnt] = x[xFlatIndex];
|
||||
cnt++;
|
||||
}
|
||||
UpdateIndexByCarry(indexValue, outputShape, dimNum);
|
||||
}
|
||||
int idx = threadNum.fetch_add(1, std::memory_order_relaxed);
|
||||
if (idx >= queueLen) {
|
||||
taskFlag.store(false);
|
||||
return;
|
||||
}
|
||||
outputIndexList[idx].startIdx = start;
|
||||
outputIndexList[idx].len = cnt;
|
||||
KERNEL_LOG_DEBUG("outputIndexList[%d] startIdx is [%lld], len is [%lld].", idx, outputIndexList[idx].startIdx,
|
||||
outputIndexList[idx].len);
|
||||
};
|
||||
constexpr int perUnitSize = 1000;
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dataNum, perUnitSize, work), "MaskedSelect calculate failed.");
|
||||
|
||||
if (!taskFlag.load()) {
|
||||
KERNEL_LOG_ERROR("Invalid array.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
|
||||
int validNum = threadNum.load();
|
||||
std::sort(outputIndexList.begin(), outputIndexList.begin() + validNum, CompareFunc);
|
||||
|
||||
int validOffset = outputIndexList[0].len;
|
||||
int64_t copyLen = 0;
|
||||
int ret = 0;
|
||||
for (int i = 1; i < validNum; i++) {
|
||||
copyLen = outputIndexList[i].len;
|
||||
if (copyLen <= 0) {
|
||||
continue;
|
||||
}
|
||||
int64_t byteLen = copyLen * static_cast<int64_t>(sizeof(T));
|
||||
ret = memmove_s(y + validOffset, byteLen, y + outputIndexList[i].startIdx, byteLen);
|
||||
KERNEL_CHECK_FALSE((ret == EOK), KERNEL_STATUS_PARAM_INVALID, "Memmove failed, result = [%d].", ret);
|
||||
validOffset += copyLen;
|
||||
}
|
||||
ctx.Output(0)->GetTensorShape()->SetDimSizes({validOffset});
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t MaskedSelectCpuKernel::MaskedSelectCompute(const CpuKernelContext &ctx) {
|
||||
T *x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(x, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[0] failed.",
|
||||
kMaskedSelect);
|
||||
bool *mask = reinterpret_cast<bool *>(ctx.Input(1)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(mask, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[1] failed.",
|
||||
kMaskedSelect);
|
||||
T *y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(y, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get output_data[0] failed.",
|
||||
kMaskedSelect);
|
||||
|
||||
auto input_shape_a = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_shape_b = ctx.Input(1)->GetTensorShape()->GetDimSizes();
|
||||
if (IsScalar(input_shape_a) && IsScalar(input_shape_b)) {
|
||||
if (mask[0]) {
|
||||
y[0] = x[0];
|
||||
ctx.Output(0)->GetTensorShape()->SetDimSizes({1});
|
||||
} else {
|
||||
ctx.Output(0)->GetTensorShape()->SetDimSizes({0});
|
||||
}
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
std::vector<int64_t> output_shape;
|
||||
auto ret = GetBroadcastShape(input_shape_a, input_shape_b, &output_shape);
|
||||
KERNEL_CHECK_FALSE(ret == KERNEL_STATUS_OK, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID),
|
||||
"Shape of x and mask can't be broadcast.");
|
||||
int64_t tensor_size = 1;
|
||||
for (const int64_t &d : output_shape) {
|
||||
tensor_size *= d;
|
||||
}
|
||||
|
||||
if (tensor_size >= kParallelDataNums) {
|
||||
ret = ParallelCompute<T>(ctx, input_shape_a, input_shape_b, output_shape, tensor_size);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int64_t j = 0;
|
||||
BroadcastIterator iter(input_shape_a, input_shape_b, &output_shape);
|
||||
iter.SetPos(0);
|
||||
for (int64_t i = 0; i < tensor_size; ++i) {
|
||||
if (mask[iter.GetInputPosB()]) {
|
||||
y[j++] = x[iter.GetInputPosA()];
|
||||
}
|
||||
iter.GenNextPos();
|
||||
}
|
||||
ctx.Output(0)->GetTensorShape()->SetDimSizes({j});
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kMaskedSelect, MaskedSelectCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_MASKED_SELECT_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_MASKED_SELECT_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class MaskedSelectCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~MaskedSelectCpuKernel() = default;
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief compute for all types
|
||||
* @param ctx cpu kernel context
|
||||
* @return status if success
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t MaskedSelectCompute(const CpuKernelContext &ctx);
|
||||
template <typename T>
|
||||
uint32_t ParallelCompute(const CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
|
||||
const std::vector<int64_t> &inputShapeMask, const std::vector<int64_t> &outputShape,
|
||||
int64_t dataNum);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,121 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "masked_select_grad.h"
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "securec.h"
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/broadcast_iterator.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
constexpr uint32_t kMaskedSelectGradInputNum = 3;
|
||||
constexpr uint32_t kMaskedSelectGradOutputNum = 1;
|
||||
const char *const kMaskedSelectGrad = "MaskedSelectGrad";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t MaskedSelectGradCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kMaskedSelectGradInputNum, kMaskedSelectGradOutputNum),
|
||||
"[%s] check params failed.", kMaskedSelectGrad);
|
||||
|
||||
// choose compute function depend on dataType
|
||||
auto data_type0 = static_cast<DataType>(ctx.Input(kFirstInputIndex)->GetDataType());
|
||||
auto data_type1 = static_cast<DataType>(ctx.Input(kSecondInputIndex)->GetDataType());
|
||||
auto data_type2 = static_cast<DataType>(ctx.Input(2)->GetDataType());
|
||||
if (data_type1 != DT_BOOL) {
|
||||
KERNEL_LOG_ERROR("[%s] Data type of mask requires bool, but got data type [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(data_type1).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
if (data_type0 != data_type2) {
|
||||
KERNEL_LOG_ERROR("[%s] Data type of x and y requires same, but got data type [%s] and [%s].",
|
||||
ctx.GetOpType().c_str(), DTypeStr(data_type0).c_str(), DTypeStr(data_type2).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
switch (data_type0) {
|
||||
case DT_FLOAT16:
|
||||
return MaskedSelectGradCompute<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return MaskedSelectGradCompute<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return MaskedSelectGradCompute<double>(ctx);
|
||||
case DT_INT8:
|
||||
return MaskedSelectGradCompute<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return MaskedSelectGradCompute<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return MaskedSelectGradCompute<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return MaskedSelectGradCompute<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return MaskedSelectGradCompute<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return MaskedSelectGradCompute<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return MaskedSelectGradCompute<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return MaskedSelectGradCompute<uint64_t>(ctx);
|
||||
case DT_BOOL:
|
||||
return MaskedSelectGradCompute<bool>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not support, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(data_type0).c_str());
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t MaskedSelectGradCpuKernel::MaskedSelectGradCompute(const CpuKernelContext &ctx) {
|
||||
bool *mask = reinterpret_cast<bool *>(ctx.Input(1)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(mask, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[1] failed.",
|
||||
kMaskedSelectGrad);
|
||||
T *grad = reinterpret_cast<T *>(ctx.Input(2)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(grad, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[2] failed.",
|
||||
kMaskedSelectGrad);
|
||||
T *dx = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
KERNEL_CHECK_NULLPTR(dx, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get output_data[0] failed.",
|
||||
kMaskedSelectGrad);
|
||||
|
||||
auto input_shape_a = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_shape_b = ctx.Input(1)->GetTensorShape()->GetDimSizes();
|
||||
std::vector<int64_t> output_shape;
|
||||
auto ret = GetBroadcastShape(input_shape_a, input_shape_b, &output_shape);
|
||||
KERNEL_CHECK_FALSE(ret == KERNEL_STATUS_OK, KERNEL_STATUS_PARAM_INVALID, "Shape of x and mask can't be broadcast.");
|
||||
int64_t tensor_size = 1;
|
||||
for (const int64_t &d : output_shape) {
|
||||
tensor_size *= d;
|
||||
}
|
||||
const T NUM_ZERO = static_cast<T>(0);
|
||||
for (int k = 0; k < tensor_size; ++k) {
|
||||
dx[k] = NUM_ZERO;
|
||||
}
|
||||
int64_t j = 0;
|
||||
BroadcastIterator iter(input_shape_a, input_shape_b, &output_shape);
|
||||
iter.SetPos(0);
|
||||
for (int64_t i = 0; i < tensor_size; ++i) {
|
||||
if (mask[iter.GetInputPosB()]) {
|
||||
dx[iter.GetInputPosA()] += grad[j++];
|
||||
}
|
||||
iter.GenNextPos();
|
||||
}
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kMaskedSelectGrad, MaskedSelectGradCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,16 +13,25 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_MASKED_SELECT_GRAD_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_MASKED_SELECT_GRAD_H_
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_ACOS_H
|
||||
#define AICPU_KERNELS_NORMALIZED_ACOS_H
|
||||
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AcosCpuKernel final : public CpuKernel {
|
||||
class MaskedSelectGradCpuKernel : public CpuKernel {
|
||||
public:
|
||||
std::uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
~MaskedSelectGradCpuKernel() = default;
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief compute for all types
|
||||
* @param ctx cpu kernel context
|
||||
* @return status if success
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t MaskedSelectGradCompute(const CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nms_with_mask.h"
|
||||
#include <numeric>
|
||||
#include "Eigen/Core"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const int32_t kInputNum = 1;
|
||||
const int32_t kOutputNum = 3;
|
||||
const int kColNum5 = 5;
|
||||
const int kColNum8 = 8;
|
||||
const char *kNMSWithMask = "NMSWithMask";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t NMSWithMaskCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
// check param
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "NMSWithMask check input or output is failed");
|
||||
AttrValue *iou_threshold = ctx.GetAttr("iou_threshold");
|
||||
KERNEL_CHECK_FALSE((iou_threshold != nullptr), KERNEL_STATUS_PARAM_INVALID, "Get attr [iou_threshold] failed.");
|
||||
iou_value_ = iou_threshold->GetFloat();
|
||||
|
||||
Tensor *input_data = ctx.Input(0);
|
||||
auto data_type = input_data->GetDataType();
|
||||
KERNEL_CHECK_FALSE((data_type == DT_FLOAT || data_type == DT_FLOAT16), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input[0] data type[%s] is unsupported", DTypeStr(data_type).c_str());
|
||||
auto input_shape = input_data->GetTensorShape()->GetDimSizes();
|
||||
num_input_ = input_shape[0]; // Get N values in [N, 5] data.
|
||||
box_size_ = input_shape[1];
|
||||
if (box_size_ != kColNum5 && box_size_ != kColNum8) {
|
||||
KERNEL_LOG_INFO("NMSWithMask the col number of input[0] must be [%d] or [%d], but got [%d]!", kColNum5, kColNum8,
|
||||
box_size_);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
uint32_t res;
|
||||
switch (data_type) {
|
||||
case DT_FLOAT16:
|
||||
res = DoCompute<Eigen::half>(ctx);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
res = DoCompute<float>(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_INFO("NMSWithMask input[0] only support type[DT_FLOAT16, DT_FLOAT], but got type[%s]",
|
||||
DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
break;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t NMSWithMaskCpuKernel::DoCompute(const CpuKernelContext &ctx) {
|
||||
auto input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output = reinterpret_cast<T *>(ctx.Output(OUTPUT)->GetData());
|
||||
auto sel_idx = reinterpret_cast<int *>(ctx.Output(SEL_IDX)->GetData());
|
||||
auto sel_boxes = reinterpret_cast<bool *>(ctx.Output(SEL_BOXES)->GetData());
|
||||
std::fill(&sel_idx[0], &sel_idx[num_input_], 0);
|
||||
std::fill(&sel_boxes[0], &sel_boxes[num_input_], false);
|
||||
|
||||
const int box_size = box_size_;
|
||||
const auto comp = [input, box_size](const size_t a, const size_t b) {
|
||||
const size_t index_a = a * box_size + 4;
|
||||
const size_t index_b = b * box_size + 4;
|
||||
if (input[index_b] == input[index_a]) {
|
||||
return a < b;
|
||||
};
|
||||
return input[index_b] < input[index_a];
|
||||
};
|
||||
std::vector<int> order(num_input_);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(), comp);
|
||||
|
||||
std::vector<T> areas(num_input_);
|
||||
for (int64_t i = 0; i < num_input_; i++) {
|
||||
areas[i] =
|
||||
(input[i * box_size_ + 2] - input[i * box_size_]) * (input[i * box_size_ + 3] - input[i * box_size_ + 1]);
|
||||
}
|
||||
|
||||
int64_t num_to_keep = 0;
|
||||
for (int64_t _i = 0; _i < num_input_; _i++) {
|
||||
auto i = order[_i];
|
||||
if (sel_boxes[i] == 1) continue;
|
||||
sel_idx[num_to_keep++] = i;
|
||||
auto ix1 = input[i * box_size_];
|
||||
auto iy1 = input[i * box_size_ + 1];
|
||||
auto ix2 = input[i * box_size_ + 2];
|
||||
auto iy2 = input[i * box_size_ + 3];
|
||||
|
||||
for (int64_t _j = _i + 1; _j < num_input_; _j++) {
|
||||
auto j = order[_j];
|
||||
if (sel_boxes[j] == 1) continue;
|
||||
auto xx1 = std::max(ix1, input[j * box_size_]);
|
||||
auto yy1 = std::max(iy1, input[j * box_size_ + 1]);
|
||||
auto xx2 = std::min(ix2, input[j * box_size_ + 2]);
|
||||
auto yy2 = std::min(iy2, input[j * box_size_ + 3]);
|
||||
|
||||
auto w = std::max(static_cast<T>(0), xx2 - xx1);
|
||||
auto h = std::max(static_cast<T>(0), yy2 - yy1);
|
||||
auto inter = w * h;
|
||||
auto ovr = inter / (areas[i] + areas[j] - inter);
|
||||
if (static_cast<float>(ovr) > iou_value_) {
|
||||
sel_boxes[j] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < num_input_; ++k) {
|
||||
for (int j = 0; j < box_size_; ++j) {
|
||||
if (k < num_to_keep) {
|
||||
output[k * kColNum5 + j] = input[sel_idx[k] * box_size_ + j];
|
||||
sel_boxes[k] = true;
|
||||
} else {
|
||||
output[k * kColNum5 + j] = static_cast<T>(0);
|
||||
sel_boxes[k] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kNMSWithMask, NMSWithMaskCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_NMS_WITH_MASK_H
|
||||
#define AICPU_KERNELS_NORMALIZED_NMS_WITH_MASK_H
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "kernel_log.h"
|
||||
#include "securec.h"
|
||||
|
||||
namespace aicpu {
|
||||
class NMSWithMaskCpuKernel : public CpuKernel {
|
||||
public:
|
||||
NMSWithMaskCpuKernel() = default;
|
||||
~NMSWithMaskCpuKernel() override = default;
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t DoCompute(const CpuKernelContext &ctx);
|
||||
|
||||
int num_input_{0};
|
||||
float iou_value_{0.0};
|
||||
size_t ceil_power_2{0};
|
||||
int box_size_ = 5; // pre_defined box width
|
||||
enum output_list_ { OUTPUT, SEL_IDX, SEL_BOXES };
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,265 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "reduce_sum.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kReduceSumInputNum = 2;
|
||||
const uint32_t kReduceSumOutputNum = 1;
|
||||
const char *const kReduceSum = "ReduceSum";
|
||||
#define REDUCESUM_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ReduceSumCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("ReduceSum kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
#define REDUCESUM_COMPUTE_CASE_COMPLEX(DTYPE, TYPE, IN_TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ReduceSumCompute2<TYPE, IN_TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("ReduceSum kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ReduceSumCpuKernel::Compute(const CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kReduceSumInputNum, kReduceSumOutputNum), "[%s] check input and output failed.",
|
||||
kReduceSum);
|
||||
KERNEL_HANDLE_ERROR(ReduceSumCheck(ctx), "[%s] check params failed.", kReduceSum);
|
||||
auto input_data_type = ctx.Input(0)->GetDataType();
|
||||
switch (input_data_type) {
|
||||
REDUCESUM_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_INT8, int8_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_INT16, int16_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
|
||||
REDUCESUM_COMPUTE_CASE_COMPLEX(DT_COMPLEX64, std::complex<float>, float, ctx)
|
||||
REDUCESUM_COMPUTE_CASE_COMPLEX(DT_COMPLEX128, std::complex<double>, double, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("ReduceSum kernel data type [%s] not support.", DTypeStr(input_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCheck(const CpuKernelContext &ctx) const {
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get output failed.");
|
||||
if (ctx.Input(1)->GetData() != nullptr) {
|
||||
KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_INT32 || ctx.Input(1)->GetDataType() == DT_INT64),
|
||||
KERNEL_STATUS_PARAM_INVALID, "Data type of axis is not support, axis data type is [%u].",
|
||||
ctx.Input(1)->GetDataType());
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute(const CpuKernelContext &ctx) {
|
||||
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
if (input_shape.size() == 0) {
|
||||
output_data[0] = input_data[0];
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
if (axes_data == nullptr) {
|
||||
int64_t data_num = ctx.Input(0)->NumElements();
|
||||
auto accumulator = static_cast<T>(0);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
accumulator += input_data[i];
|
||||
}
|
||||
output_data[0] = accumulator;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::vector<int64_t> axes;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
|
||||
int64_t output_num = ctx.Output(0)->NumElements();
|
||||
uint32_t axes_idx = 0;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumOneAxes<T>(input_data, input_shape, output_data, output_num, axes, axes_idx),
|
||||
"Reduce sum compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumOneAxes(const T *input_data, std::vector<int64_t> &input_shape, T *output_data,
|
||||
int64_t output_num, std::vector<int64_t> &axes, uint32_t &axes_idx) {
|
||||
if (axes_idx >= axes.size()) {
|
||||
for (int64_t i = 0; i < output_num; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
int64_t inner = 1, outer = 1, depth = 1;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumParseAxes(input_shape, axes, axes_idx, inner, outer, depth), "parse axes failed.");
|
||||
auto output_data_temp = new (std::nothrow) T[inner * outer];
|
||||
KERNEL_CHECK_NULLPTR(output_data_temp, KERNEL_STATUS_INNER_ERROR, "apply memory failed.");
|
||||
for (int64_t outer_index = 0; outer_index < outer; ++outer_index) {
|
||||
for (int64_t inner_index = 0; inner_index < inner; inner_index++) {
|
||||
auto accumulator = static_cast<T>(0);
|
||||
for (int64_t depth_index = 0; depth_index < depth; depth_index++) {
|
||||
int64_t index = outer_index;
|
||||
index += depth_index * outer;
|
||||
index += inner_index * depth * outer;
|
||||
accumulator += input_data[index];
|
||||
}
|
||||
int64_t output_index = outer_index;
|
||||
output_index += inner_index * outer;
|
||||
output_data_temp[output_index] = accumulator;
|
||||
}
|
||||
}
|
||||
uint32_t result = ReduceSumOneAxes<T>(output_data_temp, input_shape, output_data, output_num, axes, axes_idx);
|
||||
if (output_data_temp != nullptr) {
|
||||
delete[] output_data_temp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute2(const CpuKernelContext &ctx) {
|
||||
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
if (input_shape.size() == 0) {
|
||||
output_data[0] = std::complex<T2>(input_data[0].real(), input_data[0].imag());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
int64_t input_num = ctx.Input(0)->NumElements();
|
||||
if (axes_data == nullptr) {
|
||||
auto accumulator_real = static_cast<T2>(0);
|
||||
auto accumulator_imag = static_cast<T2>(0);
|
||||
for (int64_t i = 0; i < input_num; i++) {
|
||||
accumulator_real += input_data[i].real();
|
||||
accumulator_imag += input_data[i].imag();
|
||||
}
|
||||
output_data[0] = std::complex<T2>(accumulator_real, accumulator_imag);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::vector<int64_t> axes;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
|
||||
int64_t output_num = ctx.Output(0)->NumElements();
|
||||
uint32_t axes_idx = 0;
|
||||
KERNEL_HANDLE_ERROR(
|
||||
(ReduceSumOneAxes2<T, T2>(input_data, input_num, input_shape, output_data, output_num, axes, axes_idx)),
|
||||
"Reduce sum compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumOneAxes2(const T *input_data, int64_t input_num, std::vector<int64_t> input_shape,
|
||||
T *output_data, int64_t output_num, std::vector<int64_t> &axes,
|
||||
uint32_t &axes_idx) {
|
||||
if (axes_idx >= axes.size()) {
|
||||
auto accumulator_real = static_cast<T2>(0);
|
||||
auto accumulator_imag = static_cast<T2>(0);
|
||||
for (int64_t i = 0; i < output_num; i++) {
|
||||
accumulator_real = input_data[i].real();
|
||||
accumulator_imag = input_data[i].imag();
|
||||
output_data[i] = std::complex<T2>(accumulator_real, accumulator_imag);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
int64_t inner = 1, outer = 1, depth = 1;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumParseAxes(input_shape, axes, axes_idx, inner, outer, depth), "parse axes failed.");
|
||||
std::vector<T2> input_data_real(input_num);
|
||||
std::vector<T2> input_data_imag(input_num);
|
||||
for (int64_t i = 0; i < input_num; i++) {
|
||||
input_data_real[i] = input_data[i].real();
|
||||
input_data_imag[i] = input_data[i].imag();
|
||||
}
|
||||
int64_t output_num_temp = inner * outer;
|
||||
auto *output_data_temp = new (std::nothrow) T[output_num_temp];
|
||||
KERNEL_CHECK_NULLPTR(output_data_temp, KERNEL_STATUS_INNER_ERROR, "apply memory failed.");
|
||||
for (int64_t outer_index = 0; outer_index < outer; outer_index++) {
|
||||
for (int64_t inner_index = 0; inner_index < inner; inner_index++) {
|
||||
auto accumulator_real = static_cast<T2>(0);
|
||||
auto accumulator_imag = static_cast<T2>(0);
|
||||
for (int64_t depth_index = 0; depth_index < depth; depth_index++) {
|
||||
int64_t index = outer_index;
|
||||
index += inner_index * depth * outer;
|
||||
index += depth_index * outer;
|
||||
accumulator_real += input_data_real[index];
|
||||
accumulator_imag += input_data_imag[index];
|
||||
}
|
||||
int64_t output_index = outer_index;
|
||||
output_index += inner_index * outer;
|
||||
output_data_temp[output_index] = std::complex<T2>(accumulator_real, accumulator_imag);
|
||||
}
|
||||
}
|
||||
uint32_t result =
|
||||
ReduceSumOneAxes2<T, T2>(output_data_temp, output_num_temp, input_shape, output_data, output_num, axes, axes_idx);
|
||||
if (output_data_temp != nullptr) {
|
||||
delete[] output_data_temp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(const CpuKernelContext &ctx, std::vector<int64_t> &axes) {
|
||||
int32_t rank = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
int64_t axes_num = ctx.Input(1)->NumElements();
|
||||
for (int64_t i = 0; i < axes_num; i++) {
|
||||
int32_t axis = axes_data[i];
|
||||
KERNEL_CHECK_FALSE((axis < rank) && (axis >= -rank), KERNEL_STATUS_PARAM_INVALID,
|
||||
"axes[%d] is out of input dims rank[%d]", axis, rank);
|
||||
if (axis < 0) {
|
||||
axis += rank;
|
||||
}
|
||||
axes.push_back(axis);
|
||||
}
|
||||
int64_t j = 1;
|
||||
while (j < axes_num) {
|
||||
std::vector<int64_t>::iterator iter = find(axes.begin(), axes.begin() + j, axes[j]);
|
||||
if (iter != axes.begin() + j) {
|
||||
axes.erase(iter);
|
||||
axes_num--;
|
||||
} else {
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes,
|
||||
uint32_t &axes_idx, int64_t &inner, int64_t &outer,
|
||||
int64_t &depth) const {
|
||||
int64_t axis = axes[axes_idx];
|
||||
axes_idx++;
|
||||
int64_t rank = input_shape.size();
|
||||
for (int64_t i = 0; i < rank; i++) {
|
||||
if (i < axis) {
|
||||
inner *= input_shape[i];
|
||||
} else if (i > axis) {
|
||||
outer *= input_shape[i];
|
||||
} else {
|
||||
depth = input_shape[i];
|
||||
input_shape[i] = 1;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kReduceSum, ReduceSumCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_REDUCE_SUM_H
|
||||
#define AICPU_KERNELS_NORMALIZED_REDUCE_SUM_H
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ReduceSumCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ReduceSumCpuKernel() = default;
|
||||
~ReduceSumCpuKernel() override = default;
|
||||
|
||||
uint32_t Compute(const CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ReduceSumCheck(const CpuKernelContext &ctx) const;
|
||||
|
||||
template <typename T>
|
||||
uint32_t ReduceSumCompute(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t ReduceSumOneAxes(const T *input_data, std::vector<int64_t> &input_shape, T *output_data, int64_t output_num,
|
||||
std::vector<int64_t> &axes, uint32_t &axes_idx);
|
||||
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumCompute2(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumOneAxes2(const T *input_data, int64_t input_num, std::vector<int64_t> input_shape, T *output_data,
|
||||
int64_t output_num, std::vector<int64_t> &axes, uint32_t &axes_idx);
|
||||
|
||||
uint32_t ReduceSumDedupAxes(const CpuKernelContext &ctx, std::vector<int64_t> &axes);
|
||||
|
||||
uint32_t ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes, uint32_t &axes_idx,
|
||||
int64_t &inner, int64_t &outer, int64_t &depth) const;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sparse_group.h"
|
||||
|
||||
namespace aicpu {
|
||||
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
|
||||
++next_loc_;
|
||||
const auto &ix_t = iter_->ix_matrix_;
|
||||
const int64_t N = ix_t.dimension(0);
|
||||
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
|
||||
++next_loc_;
|
||||
}
|
||||
}
|
||||
|
||||
bool GroupIterable::IteratorStep::operator!=(const IteratorStep &rhs) const { return (rhs.loc_ != loc_); }
|
||||
|
||||
bool GroupIterable::IteratorStep::operator==(const IteratorStep &rhs) const { return (rhs.loc_ == loc_); }
|
||||
|
||||
GroupIterable::IteratorStep &GroupIterable::IteratorStep::operator++() { // prefix ++
|
||||
loc_ = next_loc_;
|
||||
UpdateEndOfGroup();
|
||||
return *this;
|
||||
}
|
||||
|
||||
const GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(int) // postfix ++
|
||||
{
|
||||
IteratorStep lhs(*this);
|
||||
++(*this);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
Group GroupIterable::IteratorStep::operator*() const { return Group(iter_, loc_, next_loc_); }
|
||||
|
||||
std::vector<int64_t> Group::group() const {
|
||||
std::vector<int64_t> g;
|
||||
const auto &ix_t = iter_->ix_matrix_;
|
||||
for (const int64_t d : iter_->group_dims_) {
|
||||
g.push_back(ix_t(loc_, d));
|
||||
}
|
||||
return g;
|
||||
}
|
||||
|
||||
TTypes<int64_t>::UnalignedConstMatrix Group::indices() const {
|
||||
return TTypes<int64_t>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)), next_loc_ - loc_, iter_->dims_);
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CPU_KERNEL_UTIL_SPARSE_GROUP_ITERATOR_H_
|
||||
#define CPU_KERNEL_UTIL_SPARSE_GROUP_ITERATOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include "eigen_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class Group; // Predeclare Group for GroupIterable.
|
||||
|
||||
// ///////////////
|
||||
// GroupIterable
|
||||
// ///////////////
|
||||
//
|
||||
// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
|
||||
//
|
||||
// Please note: the sparse_tensor should already be ordered according
|
||||
// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
|
||||
//
|
||||
// Allows grouping and iteration of the SparseTensor according to the
|
||||
// subset of dimensions provided to the group call.
|
||||
//
|
||||
// The actual grouping dimensions are stored in the
|
||||
// internal vector group_dims_. Iterators inside the iterable provide
|
||||
// the three methods:
|
||||
//
|
||||
// * group(): returns a vector with the current group dimension values.
|
||||
// * indices(): a map of index, providing the indices in
|
||||
// this group.
|
||||
// * values(): a map of values, providing the values in
|
||||
// this group.
|
||||
//
|
||||
// To iterate across GroupIterable, see examples in README.md.
|
||||
//
|
||||
|
||||
// Forward declaration of SparseTensor
|
||||
class GroupIterable {
|
||||
public:
|
||||
using VarDimArray = std::vector<int64_t>;
|
||||
|
||||
GroupIterable(Tensor *ix, Tensor *vals, int dims, const VarDimArray &group_dims)
|
||||
: ix_(ix),
|
||||
ix_matrix_(EigenTensor(ix, ix->GetData()).matrix<int64_t>()),
|
||||
vals_(vals),
|
||||
dims_(dims),
|
||||
group_dims_(group_dims.begin(), group_dims.end()) {}
|
||||
|
||||
~GroupIterable() {}
|
||||
|
||||
class IteratorStep;
|
||||
|
||||
IteratorStep begin() { return IteratorStep(this, 0); }
|
||||
|
||||
IteratorStep at(int64_t loc) {
|
||||
if (!(loc >= 0 && loc <= static_cast<int64_t>(ix_->GetTensorShape()->GetDimSize(0)))) {
|
||||
KERNEL_LOG_WARN("loc should in [0, %d], but got: %d", ix_->GetTensorShape()->GetDimSize(0), loc);
|
||||
}
|
||||
return IteratorStep(this, loc);
|
||||
}
|
||||
|
||||
IteratorStep end() { return IteratorStep(this, ix_->GetTensorShape()->GetDimSize(0)); }
|
||||
|
||||
template <typename TIX>
|
||||
inline bool GroupMatches(const TIX &ix, int64_t loc_a, int64_t loc_b) const {
|
||||
for (int64_t d : group_dims_) {
|
||||
if (ix(loc_a, d) != ix(loc_b, d)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
class IteratorStep {
|
||||
public:
|
||||
IteratorStep(GroupIterable *iter, int64_t loc) : iter_(iter), loc_(loc), next_loc_(loc_) { UpdateEndOfGroup(); }
|
||||
|
||||
~IteratorStep() { iter_ = nullptr; }
|
||||
|
||||
void UpdateEndOfGroup();
|
||||
|
||||
bool operator!=(const IteratorStep &rhs) const;
|
||||
|
||||
bool operator==(const IteratorStep &rhs) const;
|
||||
|
||||
IteratorStep &operator++();
|
||||
|
||||
const IteratorStep operator++(int);
|
||||
|
||||
Group operator*() const;
|
||||
|
||||
int64_t loc() const { return loc_; }
|
||||
|
||||
private:
|
||||
GroupIterable *iter_;
|
||||
int64_t loc_;
|
||||
int64_t next_loc_;
|
||||
};
|
||||
|
||||
private:
|
||||
friend class Group;
|
||||
Tensor *ix_;
|
||||
TTypes<int64_t>::Matrix ix_matrix_;
|
||||
Tensor *vals_;
|
||||
const int dims_;
|
||||
const std::vector<int64_t> group_dims_;
|
||||
};
|
||||
|
||||
// This class is returned when dereferencing a GroupIterable iterator.
|
||||
// It provides the methods group(), indices(), and values(), which
|
||||
// provide access into the underlying SparseTensor.
|
||||
class Group {
|
||||
public:
|
||||
Group(GroupIterable *iter, int64_t loc, int64_t next_loc) : iter_(iter), loc_(loc), next_loc_(next_loc) {}
|
||||
|
||||
~Group() { iter_ = NULL; }
|
||||
|
||||
std::vector<int64_t> group() const;
|
||||
|
||||
TTypes<int64_t>::UnalignedConstMatrix indices() const;
|
||||
|
||||
int64_t group_at(size_t index) const {
|
||||
const auto &ix_t = iter_->ix_matrix_;
|
||||
return ix_t(loc_, index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedVec values() const {
|
||||
return typename TTypes<T>::UnalignedVec(&(EigenTensor(iter_->vals_, iter_->vals_->GetData()).vec<T>()(loc_)),
|
||||
next_loc_ - loc_);
|
||||
}
|
||||
|
||||
private:
|
||||
GroupIterable *iter_;
|
||||
int64_t loc_;
|
||||
int64_t next_loc_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
|
||||
#endif // CPU_KERNEL_UTIL_SPARSE_GROUP_ITERATOR_H_
|
|
@ -0,0 +1,128 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sparse_tensor.h"
|
||||
#include "cpu_types.h"
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t SparseTensor::CreateSparseTensor(Tensor *ix, Tensor *tensorvals, std::vector<int64_t> shape,
|
||||
std::vector<int64_t> order) {
|
||||
KERNEL_LOG_INFO("Start to execute CreateSparseTensor.");
|
||||
if (ix == nullptr || ix->GetData() == nullptr) {
|
||||
KERNEL_LOG_ERROR("Ix is nullptr.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (tensorvals == nullptr || tensorvals->GetData() == nullptr) {
|
||||
KERNEL_LOG_ERROR("Vals is nullptr.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
if (ix->GetTensorShape()->GetDims() > 2) {
|
||||
KERNEL_LOG_ERROR("Index tensor dim size less than 2 or equal to 2, got size [%d] ",
|
||||
ix->GetTensorShape()->GetDims());
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
int64_t dims = (ix->GetTensorShape()->GetDims() == 0) ? 1 : ix->GetTensorShape()->GetDimSize(0);
|
||||
int64_t vals_dim0 = (tensorvals->GetTensorShape()->GetDims() == 0) ? 1 : tensorvals->GetTensorShape()->GetDimSize(0);
|
||||
if (dims != vals_dim0) {
|
||||
KERNEL_LOG_ERROR("Ix dim_size_0 [%ld] != tensorvals dim_size_0 [%ld]", dims, vals_dim0);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
dims = ix->GetTensorShape()->GetDims() == 2 ? ix->GetTensorShape()->GetDimSize(1) : 1;
|
||||
int64_t orderSize = static_cast<int64_t>(order.size());
|
||||
int64_t shapeSize = static_cast<int64_t>(shape.size());
|
||||
if (orderSize != dims) {
|
||||
KERNEL_LOG_ERROR("orderSize [%ld] != dims [%ld]", orderSize, dims);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (shapeSize != dims) {
|
||||
KERNEL_LOG_ERROR("shapeSize [%ld] != dims [%ld]", shapeSize, dims);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
ix_ = std::make_shared<EigenTensor>(ix, ix->GetData());
|
||||
vals_ = std::make_shared<EigenTensor>(tensorvals, tensorvals->GetData());
|
||||
if (ix_ == nullptr || vals_ == nullptr) {
|
||||
KERNEL_LOG_ERROR("Indices or values create eigen tensor failed.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
shape_.assign(shape.begin(), shape.end());
|
||||
order_.assign(order.begin(), order.end());
|
||||
dims_ = static_cast<int32_t>(dims);
|
||||
KERNEL_LOG_INFO("Execute CreateSparseTensor end");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t SparseTensor::IndicesValid(CpuKernelContext &ctx) const {
|
||||
if (std::any_of(order_.begin(), order_.end(), [](int64_t ord) { return ord < 0; })) {
|
||||
KERNEL_LOG_ERROR("Order was not provided.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (ix_->GetTensor()->GetDataType() == DT_INT32) {
|
||||
if (EigenTensorIndicesValid<int32_t>(ctx) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Indices valid failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
} else {
|
||||
if (EigenTensorIndicesValid<int64_t>(ctx) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Indices valid failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
bool SparseTensor::ValidateToDense(const Tensor *out) const {
|
||||
KERNEL_LOG_INFO("Start execute ValidateToDense.");
|
||||
if (out->GetDataType() != vals_->GetTensor()->GetDataType()) {
|
||||
KERNEL_LOG_ERROR("Output data type must match vals, got out [%d], vals [%d].", out->GetDataType(),
|
||||
vals_->GetTensor()->GetDataType());
|
||||
return false;
|
||||
}
|
||||
if (out->GetTensorShape()->GetDims() != dims_) {
|
||||
KERNEL_LOG_ERROR("Output dims must match idx, got output dims [%d], idx dims [%d].",
|
||||
out->GetTensorShape()->GetDims(), dims_);
|
||||
return false;
|
||||
}
|
||||
const auto out_shape = out->GetTensorShape();
|
||||
int32_t shapeSize = static_cast<int32_t>(shape_.size());
|
||||
if (shapeSize != out_shape->GetDims()) {
|
||||
KERNEL_LOG_ERROR("output dims must match shape dims, got output dim [%d], shape dim [%d].", out_shape->GetDims(),
|
||||
shapeSize);
|
||||
return false;
|
||||
}
|
||||
for (size_t d = 0; d < shape_.size(); ++d) {
|
||||
if (shape_[d] > out_shape->GetDimSize(static_cast<int32_t>(d))) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Valid output shape dims value failed, index [%zu], shape value [%ld], "
|
||||
"greater than output shape value [%d].",
|
||||
d, shape_[d], out_shape->GetDimSize(static_cast<int32_t>(d)));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
KERNEL_LOG_INFO("Execute Validate dense end.");
|
||||
return true;
|
||||
}
|
||||
|
||||
GroupIterable SparseTensor::group(const std::vector<int64_t> &group_ix) const {
|
||||
if (group_ix.size() > static_cast<size_t>(dims_)) {
|
||||
KERNEL_LOG_WARN("Grop_ix.size:%zu > dims_:%d", group_ix.size(), dims_);
|
||||
}
|
||||
return GroupIterable(const_cast<Tensor *>(ix_->GetTensor()), const_cast<Tensor *>(vals_->GetTensor()), dims_,
|
||||
group_ix);
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,296 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_SPARSETENSOR_H
|
||||
#define AICPU_SPARSETENSOR_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "cpu_tensor.h"
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "kernel_log.h"
|
||||
#include "sparse_group.h"
|
||||
#include "status.h"
|
||||
|
||||
namespace aicpu {
|
||||
template <typename T>
|
||||
const T SubtleMustCopy(const T &x) {
|
||||
auto *to_x = reinterpret_cast<const volatile T *>(&x);
|
||||
return *to_x;
|
||||
}
|
||||
} // namespace aicpu
|
||||
|
||||
namespace aicpu {
|
||||
class SparseTensor {
|
||||
public:
|
||||
SparseTensor() : dims_(0) {}
|
||||
~SparseTensor() = default;
|
||||
|
||||
/*
|
||||
* create sparse tensor
|
||||
* @param ix: index tensor
|
||||
* @param tensorvals: tensorvals tensor
|
||||
* @param shape: shape vec
|
||||
* @param order: order vec
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t CreateSparseTensor(Tensor *ix, Tensor *tensorvals, std::vector<int64_t> shape, std::vector<int64_t> order);
|
||||
|
||||
/*
|
||||
* sparse indices valid
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t IndicesValid(CpuKernelContext &ctx) const;
|
||||
|
||||
/*
|
||||
* group sparse tensor
|
||||
* @return GroupIterable
|
||||
*/
|
||||
GroupIterable group(const std::vector<int64_t> &group_ix) const;
|
||||
/*
|
||||
* sparse eigen tensor indices valid
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t EigenTensorIndicesValidCheck(int64_t dims_size) const {
|
||||
const auto ix_t = ix_->matrix<T>();
|
||||
for (int64_t n = 1; n < dims_size; ++n) {
|
||||
bool valid = true;
|
||||
bool different = false;
|
||||
bool increasing = true;
|
||||
for (int32_t di = 0; di < dims_; ++di) {
|
||||
if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) {
|
||||
valid = false;
|
||||
}
|
||||
int64_t diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
|
||||
if (diff > 0) {
|
||||
different = true;
|
||||
}
|
||||
if (!different && diff < 0) {
|
||||
increasing = false;
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
KERNEL_LOG_ERROR("Indices is out of bounds, index=%lld.", n);
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
if (!increasing) {
|
||||
KERNEL_LOG_ERROR("indices is out of order, index=%lld.", n);
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
if (!different) {
|
||||
KERNEL_LOG_ERROR("indices is repeated, index=%lld.", n);
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
}
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
/*
|
||||
* sparse eigen tensor indices valid
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t EigenTensorIndicesValidParaCheck(const CpuKernelContext &ctx, int64_t dims_size) const {
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
uint32_t result = static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
(void)aicpu::CpuKernelUtils::ParallelFor(ctx, dims_size, dims_size / max_core_num,
|
||||
[&](std::int64_t begin, std::int64_t end) {
|
||||
int64_t start = begin;
|
||||
if (begin == 0) {
|
||||
start = begin + 1;
|
||||
}
|
||||
const auto ix_t = ix_->matrix<T>();
|
||||
for (int64_t n = start; n < end; ++n) {
|
||||
bool valid = true;
|
||||
bool different = false;
|
||||
bool increasing = true;
|
||||
for (int32_t di = 0; di < dims_; ++di) {
|
||||
if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) {
|
||||
valid = false;
|
||||
}
|
||||
int64_t diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
|
||||
if (diff > 0) {
|
||||
different = true;
|
||||
}
|
||||
if (!different && diff < 0) {
|
||||
increasing = false;
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
KERNEL_LOG_ERROR("Indices is out of bounds, index=%lld.", n);
|
||||
result = static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
return;
|
||||
}
|
||||
if (!increasing) {
|
||||
KERNEL_LOG_ERROR("indices is out of order, index=%lld.", n);
|
||||
result = static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
return;
|
||||
}
|
||||
if (!different) {
|
||||
KERNEL_LOG_ERROR("indices is repeated, index=%lld.", n);
|
||||
result = static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
/*
|
||||
* sparse eigen tensor indices valid
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename T>
|
||||
uint32_t EigenTensorIndicesValid(const CpuKernelContext &ctx) const {
|
||||
const auto ix_t = ix_->matrix<T>();
|
||||
int64_t dims_size =
|
||||
(ix_->GetTensor()->GetTensorShape()->GetDims() == 0) ? 1 : ix_->GetTensor()->GetTensorShape()->GetDimSize(0);
|
||||
if (dims_size > 0) {
|
||||
for (int32_t di = 0; di < dims_; ++di) {
|
||||
if ((ix_t(0, di) < 0) || (ix_t(0, di) >= shape_[di])) {
|
||||
KERNEL_LOG_ERROR("Indices is out of bounds, index=0.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
}
|
||||
const int64_t paralled_data_size = 16 * 1024;
|
||||
if (dims_size < paralled_data_size) {
|
||||
return EigenTensorIndicesValidCheck<T>(dims_size);
|
||||
} else {
|
||||
return EigenTensorIndicesValidParaCheck<T>(ctx, dims_size);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* validate sparse to dense
|
||||
* @param output: output tensor
|
||||
* @return bool: true->success false->failed
|
||||
*/
|
||||
bool ValidateToDense(const Tensor *out) const;
|
||||
|
||||
/*
|
||||
* sparse tensor to dense tensor
|
||||
* @param output: output tensor
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename IndiceT, typename ValueT>
|
||||
uint32_t ToDenseParallel(const CpuKernelContext &ctx, Tensor *output) {
|
||||
EigenTensor outputET(output, output->GetData());
|
||||
auto output_t = outputET.flat<ValueT>();
|
||||
auto ix_t = ix_->matrix<IndiceT>();
|
||||
std::vector<int64_t> strides(dims_);
|
||||
const auto &out_shape = output->GetTensorShape();
|
||||
if (dims_ > 0) {
|
||||
strides[dims_ - 1] = 1;
|
||||
}
|
||||
for (int32_t d = dims_ - 2; d >= 0; --d) {
|
||||
strides[d] = strides[d + 1] * out_shape->GetDimSize(d + 1);
|
||||
}
|
||||
auto vals_t = vals_->vec<ValueT>();
|
||||
int64_t vals_size = vals_t.dimension(0);
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
uint32_t result = static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
auto parallel_proc = [&](std::int64_t begin, std::int64_t end) {
|
||||
for (int64_t n = begin; n < end; ++n) {
|
||||
bool invalid_dims = false;
|
||||
int64_t ix = 0;
|
||||
for (int d = 0; d < dims_; ++d) {
|
||||
const int64_t ix_n_d = ix_t(n, d);
|
||||
if (ix_n_d > out_shape->GetDimSize(d)) {
|
||||
invalid_dims = true;
|
||||
}
|
||||
ix += strides[d] * ix_n_d;
|
||||
}
|
||||
if (invalid_dims) {
|
||||
result = static_cast<uint32_t>(KERNEL_STATUS_INNER_ERROR);
|
||||
KERNEL_LOG_ERROR("Sparse to dense got invalid dims.");
|
||||
return;
|
||||
}
|
||||
output_t(ix) = vals_t(n);
|
||||
}
|
||||
return;
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(aicpu::CpuKernelUtils::ParallelFor(ctx, vals_size, vals_size / max_core_num, parallel_proc),
|
||||
"SparseToDense Compute failed.");
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* sparse tensor to dense tensor
|
||||
* @param output: output tensor
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename IndiceT, typename ValueT>
|
||||
uint32_t ToDense(const CpuKernelContext &ctx, Tensor *output) {
|
||||
KERNEL_LOG_INFO("Start to execute ToDense.");
|
||||
if (output == nullptr || output->GetData() == nullptr) {
|
||||
KERNEL_LOG_ERROR("Output tensor is nullptr.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (!ValidateToDense(output)) {
|
||||
KERNEL_LOG_ERROR("Validate to dense param failed.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
auto vals_t = vals_->vec<ValueT>();
|
||||
int64_t vals_size = vals_t.dimension(0);
|
||||
const int64_t paralled_data_size = 16 * 1024;
|
||||
if (vals_size >= paralled_data_size) {
|
||||
return ToDenseParallel<IndiceT, ValueT>(ctx, output);
|
||||
}
|
||||
EigenTensor outputET(output, output->GetData());
|
||||
auto output_t = outputET.flat<ValueT>();
|
||||
auto ix_t = ix_->matrix<IndiceT>();
|
||||
std::vector<int64_t> strides(dims_);
|
||||
const auto &out_shape = output->GetTensorShape();
|
||||
if (dims_ > 0) {
|
||||
strides[dims_ - 1] = 1;
|
||||
}
|
||||
for (int32_t d = dims_ - 2; d >= 0; --d) {
|
||||
strides[d] = strides[d + 1] * out_shape->GetDimSize(d + 1);
|
||||
}
|
||||
for (int64_t n = 0; n < vals_size; ++n) {
|
||||
bool invalid_dims = false;
|
||||
int64_t ix = 0;
|
||||
for (int d = 0; d < dims_; ++d) {
|
||||
const int64_t ix_n_d = ix_t(n, d);
|
||||
if (ix_n_d > out_shape->GetDimSize(d)) {
|
||||
invalid_dims = true;
|
||||
}
|
||||
ix += strides[d] * ix_n_d;
|
||||
}
|
||||
if (invalid_dims) {
|
||||
KERNEL_LOG_ERROR("Sparse to dense got invalid dims.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
output_t(ix) = vals_t(n);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<EigenTensor> ix_;
|
||||
std::shared_ptr<EigenTensor> vals_;
|
||||
std::vector<int64_t> shape_;
|
||||
std::vector<int64_t> order_;
|
||||
int32_t dims_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
|
||||
#endif // AICPU_SPARSETENSOR_H
|
|
@ -48,14 +48,18 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
kSliceGradOpName,
|
||||
kRandomShuffleOpName,
|
||||
kRangeOpName};
|
||||
static const std::set<std::string> kMigrateAicpuKernelOps = {
|
||||
mindspore::kACosOpName,
|
||||
mindspore::kLogMatrixDeterminantOpName,
|
||||
mindspore::kAdaptiveAvgPool2dOpName,
|
||||
mindspore::kAdaptiveAvgPool2dGradOpName,
|
||||
mindspore::kMedianOpName,
|
||||
mindspore::kMedianGradOpName,
|
||||
};
|
||||
static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kACosOpName,
|
||||
mindspore::kAdaptiveAvgPool2dOpName,
|
||||
mindspore::kAdaptiveAvgPool2dGradOpName,
|
||||
mindspore::kCacheSwapTableOpName,
|
||||
mindspore::kFillOpName,
|
||||
mindspore::kLogMatrixDeterminantOpName,
|
||||
mindspore::kMaskedSelectOpName,
|
||||
mindspore::kMaskedSelectGradOpName,
|
||||
mindspore::kMedianOpName,
|
||||
mindspore::kMedianGradOpName,
|
||||
mindspore::kNMSWithMaskOpName,
|
||||
mindspore::kReduceSumOpName};
|
||||
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
|
||||
static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";
|
||||
|
||||
|
|
Loading…
Reference in New Issue