add aicpu_ops

This commit is contained in:
gaoshuanglong 2023-01-05 15:37:54 +08:00
parent 88833bada9
commit f00862beb2
38 changed files with 3360 additions and 2 deletions

View File

@ -86,6 +86,7 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constVariable"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "redundantAssignment"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constArgument"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "unknownMacro"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"

View File

@ -132,3 +132,6 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/newline"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/operators"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comma"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/indentation_namespace"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/line_length"

View File

@ -282,6 +282,9 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d.cc:aicpu::MaxUnpool2DCpuKernel::MaxUnpool2DCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_solve_ls.cc:aicpu::MatrixSolveLsCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/col2im.cc:aicpu::Col2imCpuKernel::Col2imParamCheck
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc:aicpu::CSRSparseMatrixToSparseTensorCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc:aicpu::CumprodCpuKernel::CumprodCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc:aicpu::CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd_update.cc:aicpu::ScatterNdUpdateCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/ragged_tensor_to_sparse.cc:aicpu::RaggedTensorToSparseCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_3d_grad.cc:aicpu::MaxUnpool3DGradCpuKernel::MaxUnpool3DGradCompute

View File

@ -149,15 +149,21 @@ constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
constexpr auto kBpropCutOpName = "bprop_cut";
constexpr auto kBroadcastToOpName = "BroadcastTo";
constexpr auto kBucketizeOpName = "Bucketize";
constexpr auto kIFMROpName = "IFMR";
constexpr auto kBroadcastToDOpName = "BroadcastToD";
constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
constexpr auto kCallOpName = "call";
constexpr auto kCauchyOpName = "Cauchy";
constexpr auto kCastOpName = "Cast";
constexpr auto kCentralizationOpName = "Centralization";
constexpr auto kCeLUOpName = "CeLU";
constexpr auto kCeluV2OpName = "CeluV2";
constexpr auto kChannelShuffleOpName = "ChannelShuffle";
constexpr auto kCheckNumericsOpName = "CheckNumerics";
constexpr auto kCholeskyGradOpName = "CholeskyGrad";
constexpr auto kCholeskyInverseOpName = "CholeskyInverse";
constexpr auto kCholeskySolveOpName = "CholeskySolve";
constexpr auto kClearZeroOpName = "ClearZero";
constexpr auto kClipBoxesOpName = "kClipBoxes";
constexpr auto kClipBoxesDOpName = "kClipBoxesD";
@ -165,8 +171,11 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
constexpr auto kClipByValueOpName = "ClipByValue";
constexpr auto kCoalesceOpName = "Coalesce";
constexpr auto kCol2imOpName = "Col2im";
constexpr auto kCombinedNonMaxSuppressionOpName = "CombinedNonMaxSuppression";
constexpr auto kCombineMomentumOpName = "CombineMomentum";
constexpr auto kCombineMomentumWeightOpName = "CombineMomentumWeight";
constexpr auto kComplexOpName = "Complex";
constexpr auto kComplexAbsOpName = "ComplexAbs";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
constexpr auto kConcatOpName = "Concat";
constexpr auto kConcatDOpName = "ConcatD";
@ -193,6 +202,8 @@ constexpr auto kCropAndResizeOpName = "CropAndResize";
constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
constexpr auto kConvBN1OpName = "ConvBN1";
constexpr auto kCOO2CSROpName = "COO2CSR";
constexpr auto kCosOpName = "Cos";
constexpr auto kCountNonZeroOpName = "CountNonZero";
constexpr auto kCSR2COOOpName = "CSR2COO";
constexpr auto kCSRDivOpName = "CSRDiv";
constexpr auto kCSRGatherOpName = "CSRGather";
@ -200,6 +211,7 @@ constexpr auto kCSRMMOpName = "CSRMM";
constexpr auto kCSRMulOpName = "CSRMul";
constexpr auto kCSRMVOpName = "CSRMV";
constexpr auto kCSRReduceSumOpName = "CSRReduceSum";
constexpr auto kCSRSparseMatrixToDenseOpName = "CSRSparseMatrixToDense";
constexpr auto kCSRSparseMatrixToSparseTensorOpName = "CSRSparseMatrixToSparseTensor";
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
constexpr auto kCumprodOpName = "Cumprod";

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bucketize.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 1;
const char *kBucketize = "Bucketize";
const int64_t kParallelDataNumSameShape = 64 * 1024;
const int64_t kParallelDataNumSameShapeMid = 35 * 1024;
#define BUCKETIZE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = BucketizeCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Bucketize kernel compute failed."); \
return result; \
} \
break; \
}
int64_t get_tensor_length(aicpu::Tensor *t) {
std::vector<int64_t> dim_sizes = t->GetTensorShape()->GetDimSizes();
int64_t length = 1;
for (auto x : dim_sizes) {
length *= x;
}
return length;
}
} // namespace
namespace aicpu {
uint32_t BucketizeCpuKernel::Compute(CpuKernelContext &ctx) {
// normal check
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Bucketize check input and output number failed.");
auto data_type = ctx.Input(0)->GetDataType();
KERNEL_CHECK_NULLPTR(ctx.GetAttr("boundaries"), KERNEL_STATUS_PARAM_INVALID, "Get boundaries failed")
// check input datatype
Tensor *input = ctx.Input(kFirstInputIndex);
DataType dt_input = input->GetDataType();
KERNEL_CHECK_FALSE((dt_input == DT_FLOAT || dt_input == DT_INT32 || dt_input == DT_INT64 || dt_input == DT_DOUBLE),
KERNEL_STATUS_PARAM_INVALID,
"Input data type must DT_FLOAT or DT_INT32 or DT_INT64 or DT_DOUBLE,"
"but got data type[%s].",
DTypeStr(dt_input).c_str());
// check output datatype
Tensor *output = ctx.Output(kFirstOutputIndex);
DataType dt_output = output->GetDataType();
KERNEL_CHECK_FALSE((dt_output == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
"Output data type must DT_INT32, but got data type[%s].", DTypeStr(dt_output).c_str());
auto input_sizes = input->GetTensorShape()->GetDimSizes();
auto output_sizes = output->GetTensorShape()->GetDimSizes();
KERNEL_CHECK_FALSE((input_sizes == output_sizes), KERNEL_STATUS_PARAM_INVALID,
"The tensor shape of input [%s] need be same with "
"output [%s].",
VectorToString(input_sizes).c_str(), VectorToString(output_sizes).c_str());
switch (data_type) {
BUCKETIZE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
BUCKETIZE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
BUCKETIZE_COMPUTE_CASE(DT_FLOAT, float, ctx)
BUCKETIZE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
default:
KERNEL_LOG_ERROR("Bucketize kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t BucketizeCpuKernel::BucketizeCompute(CpuKernelContext &ctx) {
const int64_t data_num = get_tensor_length(ctx.Input(0));
auto boundaries = ctx.GetAttr("boundaries");
std::vector<float> boundaries_data = boundaries->GetListFloat();
std::sort(boundaries_data.begin(), boundaries_data.end());
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_data = reinterpret_cast<int32_t *>(ctx.Output(0)->GetData());
if (data_num >= kParallelDataNumSameShape) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumSameShapeMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_bucketize = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
output_data[i] = first_bigger_it - boundaries_data.begin();
}
};
CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_bucketize);
} else {
for (int64_t i = 0; i < data_num; i++) {
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
output_data[i] = first_bigger_it - boundaries_data.begin();
}
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kBucketize, BucketizeCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
#define AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class BucketizeCpuKernel : public CpuKernel {
public:
BucketizeCpuKernel() = default;
~BucketizeCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
template <typename T>
static uint32_t BucketizeCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cauchy.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <vector>
#include <random>
namespace {
const char *kCauchy = "Cauchy";
const int64_t kParallelDataNums = 64 * 1024;
const uint32_t knum = 2;
} // namespace
#define CAUCHY_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = CauchyCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Cauchy kernel compute failed."); \
return result; \
} \
break; \
}
namespace aicpu {
uint32_t CauchyCpuKernel::Compute(CpuKernelContext &ctx) {
Tensor *output_tensor = ctx.Output(0);
auto output_dtype = output_tensor->GetDataType();
KERNEL_CHECK_NULLPTR(output_tensor, KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_tensor is nullptr.");
switch (output_dtype) {
CAUCHY_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
CAUCHY_COMPUTE_CASE(DT_FLOAT, float, ctx)
default:
KERNEL_LOG_ERROR("Cauchy kernel data type [%s] not support.", DTypeStr(output_dtype).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
};
template <typename T>
uint32_t CauchyCpuKernel::CauchyCompute(CpuKernelContext &ctx) {
AttrValue *median = ctx.GetAttr("median");
if (median != nullptr) {
median_ = median->GetFloat();
}
AttrValue *sigma = ctx.GetAttr("sigma");
if (sigma != nullptr) {
sigma_ = sigma->GetFloat();
}
Tensor *y_tensor = ctx.Output(0);
AttrValue *output_size_attr = ctx.GetAttr("size");
KERNEL_CHECK_NULLPTR(output_size_attr, KERNEL_STATUS_PARAM_INVALID, "CauchyCompute get size failed.");
std::vector<int64_t> output_size = ctx.GetAttr("size")->GetListInt();
if (output_size.empty()) {
KERNEL_LOG_ERROR("CauchyCompute get size is empty.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto y_shape = y_tensor->GetTensorShape();
y_shape->SetDimSizes(output_size);
int64_t y_num = y_tensor->NumElements();
KERNEL_CHECK_NULLPTR(y_tensor->GetData(), KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_data is nullptr.");
T *y_data = static_cast<T *>(y_tensor->GetData());
std::default_random_engine generator(std::random_device{}());
std::cauchy_distribution<float> cauchy_d(median_, sigma_);
uint32_t max_core_num = 1;
if (y_num >= kParallelDataNums) {
max_core_num = std::max(max_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - knum);
if (max_core_num > y_num) {
max_core_num = y_num;
}
}
auto Cauchy_d = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
float data = cauchy_d(generator);
y_data[i] = static_cast<T>(data);
}
};
uint32_t ret = CpuKernelUtils::ParallelFor(ctx, y_num, y_num / max_core_num, Cauchy_d);
if (ret != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("CpuKernelUtils::ParallelFor failed.");
return KERNEL_STATUS_INNER_ERROR;
}
KERNEL_LOG_INFO("CauchyCpuKernel::ComputeCauchy end.");
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCauchy, CauchyCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
#define AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
#define EIGEN_USE_THREADS
#define EIGEN_USE_SIMPLE_THREAD_POOL
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace aicpu {
class CauchyCpuKernel : public CpuKernel {
public:
CauchyCpuKernel() = default;
~CauchyCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
uint32_t CauchyCompute(CpuKernelContext &ctx);
float median_ = 0.0;
float sigma_ = 1.0;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cholesky_grad.h"
#include <algorithm>
#include <iostream>
#include <map>
#include "cpu_kernel_utils.h"
namespace {
const uint32_t kInputNum = 2;
const uint32_t kOutputNum = 1;
const char *CholeskyGrad = "CholeskyGrad";
} // namespace
namespace aicpu {
uint32_t CholeskyGradCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskyGrad check input and output number failed.");
Tensor *input0 = ctx.Input(0);
Tensor *input1 = ctx.Input(1);
Tensor *output0 = ctx.Output(0);
if (input0->GetDataSize() == 0 || input1->GetDataSize() == 0) {
KERNEL_LOG_ERROR("[%s] Input tensor is empty tensor.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto shape0 = input0->GetTensorShape();
auto shape1 = input1->GetTensorShape();
auto shape2 = output0->GetTensorShape();
if (shape0->GetDims() != shape1->GetDims() || shape1->GetDims() != shape2->GetDims()) {
KERNEL_LOG_ERROR("[%s] Inputs and Output tensors should have same dims.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto dims = shape0->GetDims();
if (shape0->GetDimSize(dims - 1) != shape0->GetDimSize(dims - 2)) {
KERNEL_LOG_ERROR("[%s] Tensor input0 is not square.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if ((shape0->GetDimSize(dims - 1) != shape1->GetDimSize(dims - 1)) ||
(shape0->GetDimSize(dims - 2) != shape1->GetDimSize(dims - 2))) {
KERNEL_LOG_ERROR("[%s] Tensor input0&input1's shape mismatch.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if ((shape0->GetDimSize(dims - 1) != shape2->GetDimSize(dims - 1)) ||
(shape0->GetDimSize(dims - 2) != shape2->GetDimSize(dims - 2))) {
KERNEL_LOG_ERROR("[%s] Tensor input0&output0's shape mismatch.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto data_type_0 = input0->GetDataType();
auto data_type_1 = input1->GetDataType();
auto data_type_2 = output0->GetDataType();
if (data_type_0 != data_type_1 || data_type_0 != data_type_2) {
KERNEL_LOG_ERROR("[%s] Tensor data type mismatch.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (data_type_0 != DT_FLOAT && data_type_0 != DT_DOUBLE) {
KERNEL_LOG_ERROR("CholeskyGrad kernel data type [%u] not support.", data_type_0);
return KERNEL_STATUS_PARAM_INVALID;
}
if (data_type_0 == DT_FLOAT) {
return ComputeKernel<float>(ctx, true);
} else {
return ComputeKernel<double>(ctx, true);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CholeskyGradCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &reverse) {
auto dims = ctx.Input(0)->GetTensorShape()->GetDims();
auto lptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto gradptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto outputptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int n = ctx.Input(0)->GetTensorShape()->GetDimSize(dims - 1);
int64_t data_num = ctx.Input(0)->NumElements();
const int64_t mat_size = n * n;
const int64_t batch = data_num / mat_size;
const int64_t kParallelDataNum = 16 * mat_size;
const int64_t kParallelDataNumMid = 72 * mat_size;
if (data_num >= kParallelDataNum) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
if (data_num <= kParallelDataNumMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
auto sharder_cholesky_grad = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch, batch / max_core_num, sharder_cholesky_grad),
"CholeskyGrad Compute failed.");
} else {
for (int64_t i = 0; i < batch; i++) {
ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
}
}
return KERNEL_STATUS_OK;
}
template <typename T>
void CholeskyGradCpuKernel::ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigengrad(n, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigenl(n, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> output_matrix(n, n);
for (int i = 0; i < n * n; i++) {
*(eigengrad.data() + i) = *(gradptr + i);
*(eigenl.data() + i) = *(lptr + i);
}
// Algorithm only depends on lower triangular half on input_matrix_l.
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> input_matrix_l =
eigenl.template triangularView<Eigen::Lower>();
// Algorithm only depends on lower triangular half on input_matrix_grad.
output_matrix = eigengrad.template triangularView<Eigen::Lower>();
const int64_t kMatrixSize = input_matrix_l.rows();
const int64_t kMaxBlockSize = 32;
for (int64_t block_end = kMatrixSize; block_end > 0; block_end -= kMaxBlockSize) {
const int64_t block_begin = std::max(int64_t{0}, block_end - kMaxBlockSize);
const int64_t block_size = block_end - block_begin;
const int64_t trailing_size = kMatrixSize - block_end;
auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
auto B_bar = output_matrix.block(block_end, 0, trailing_size, block_begin);
auto C = input_matrix_l.block(block_end, block_begin, trailing_size, block_size);
auto C_bar = output_matrix.block(block_end, block_begin, trailing_size, block_size);
auto D = input_matrix_l.block(block_begin, block_begin, block_size, block_size);
auto D_bar = output_matrix.block(block_begin, block_begin, block_size, block_size);
auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
C_bar = D.adjoint().template triangularView<Eigen::Upper>().solve(C_bar.adjoint()).adjoint();
D_bar -= (C_bar.adjoint() * C).template triangularView<Eigen::Lower>();
B_bar -= C_bar * R;
R_bar -= C_bar.adjoint() * B;
CholeskyGradUnblocked<T>(D, D_bar);
R_bar -= (D_bar + D_bar.adjoint()) * R;
}
output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
*(outputptr + i * n + j) = output_matrix(i, j);
}
}
}
template <typename T>
void CholeskyGradCpuKernel::CholeskyGradUnblocked(
const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block) {
const int64_t kMatrixSize = l_block.rows();
for (int64_t k = kMatrixSize - 1; k >= 0; k--) {
const int64_t number_rows_B = kMatrixSize - (k + 1);
const int64_t number_rows_r_stack_B = number_rows_B + 1;
auto r = l_block.block(k, 0, 1, k);
auto r_bar = grad_block.block(k, 0, 1, k);
auto d = l_block(k, k); // This needs to be a scalar rather than a view.
auto d_bar = grad_block.block(k, k, 1, 1);
// B is not included explicitly because it is not used on its own.
auto B_bar = grad_block.block(k + 1, 0, number_rows_B, k);
auto c = l_block.block(k + 1, k, number_rows_B, 1);
auto c_bar = grad_block.block(k + 1, k, number_rows_B, 1);
// Result of vertical stacking d_bar and c_bar.
auto d_stack_c_bar = grad_block.block(k, k, number_rows_r_stack_B, 1);
// Result of vertical stacking of r and B.
auto r_stack_B = l_block.block(k, 0, number_rows_r_stack_B, k);
d_bar -= (c.adjoint() * c_bar) / d;
d_stack_c_bar /= d;
r_bar -= d_stack_c_bar.adjoint() * r_stack_B;
B_bar -= c_bar * r;
d_bar /= 2.;
}
}
REGISTER_CPU_KERNEL(CholeskyGrad, CholeskyGradCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
#define AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
#include <Eigen/Dense>
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <unsupported/Eigen/MatrixFunctions>
namespace aicpu {
class CholeskyGradCpuKernel : public CpuKernel {
public:
CholeskyGradCpuKernel() = default;
~CholeskyGradCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &reverse);
template <typename T>
void ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n);
template <typename T>
void CholeskyGradUnblocked(
const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cholesky_inverse.h"
#include "cpu_kernel_utils.h"
#include "utils/kernel_util.h"
#include <Eigen/Dense>
#include <iostream>
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 1;
const uint32_t dimension = 2;
const char *kCholeskyInverse = "CholeskyInverse";
} // namespace
namespace aicpu {
uint32_t CholeskyInverseCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check CholeskyInverse params failed.");
Tensor *input = ctx.Input(0);
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");
Tensor *output = ctx.Output(0);
auto inputShape = input->GetTensorShape();
KERNEL_CHECK_NULLPTR(inputShape, KERNEL_STATUS_PARAM_INVALID, "Get inputShape failed.");
AttrValue *upper = ctx.GetAttr("upper");
KERNEL_CHECK_NULLPTR(upper, KERNEL_STATUS_PARAM_INVALID, "Get upper failed.");
KERNEL_LOG_DEBUG(
"CholeskyInverseCpuKernel[%s], input: size[%llu];"
"output: size[%llu].",
ctx.GetOpType().c_str(), input->GetDataSize(), output->GetDataSize());
auto input_dims = inputShape->GetDims();
if (input_dims != dimension) {
KERNEL_LOG_ERROR("CholeskyInverse input dim must be 2!");
return KERNEL_STATUS_PARAM_INVALID;
} else if (inputShape->GetDimSize(input_dims - 2) != inputShape->GetDimSize(input_dims - 1)) {
KERNEL_LOG_ERROR("CholeskyInverse input matrix must be square matrix!");
return KERNEL_STATUS_PARAM_INVALID;
}
DataType data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
case DT_FLOAT:
return CholeskyInverseCompute<float>(ctx);
case DT_DOUBLE:
return CholeskyInverseCompute<double>(ctx);
default:
KERNEL_LOG_ERROR("CholeskyInverse kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CholeskyInverseCpuKernel::CholeskyInverseCompute(CpuKernelContext &ctx) {
auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
auto inputShape = ctx.Input(0)->GetTensorShape();
int64_t n = inputShape->GetDimSize(0);
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
Eigen::Map<MatrixXd> A(input_x, n, n);
MatrixXd result;
AttrValue *upper = ctx.GetAttr("upper");
bool val = upper->GetBool();
if (val) {
result = (A.transpose() * A).inverse();
} else {
result = (A * A.transpose()).inverse();
}
for (int64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < n; j++) {
*(output_y + i * n + j) = result(i, j);
}
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCholeskyInverse, CholeskyInverseCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,19 @@
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
#define AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class CholeskyInverseCpuKernel : public CpuKernel {
public:
CholeskyInverseCpuKernel() = default;
~CholeskyInverseCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
static uint32_t CholeskyInverseCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cholesky_solve.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <Eigen/Dense>
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include "cpu_kernel_utils.h"
namespace {
const uint32_t kInputNum = 2;
const uint32_t kOutputNum = 1;
const char *CholeskySolve = "CholeskySolve";
} // namespace
namespace aicpu {
uint32_t CholeskySolveCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskySolve check input and output failed.");
Tensor *input_x1 = ctx.Input(0);
AttrValue *upper = ctx.GetAttr("upper");
bool upperinfo = (upper == nullptr) ? false : upper->GetBool();
auto data_type_x1 = input_x1->GetDataType();
switch (data_type_x1) {
case DT_FLOAT:
return ComputeKernel<float>(ctx, upperinfo);
case DT_DOUBLE:
return ComputeKernel<double>(ctx, upperinfo);
default:
KERNEL_LOG_ERROR("CholeskySolve kernel data type [%s] not support.", DTypeStr(data_type_x1).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
REGISTER_CPU_KERNEL(CholeskySolve, CholeskySolveCpuKernel);
template <typename T>
uint32_t CholeskySolveCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &upper) {
auto rhsptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto lhsptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto outptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
size_t batch_size = 1;
std::vector<int64_t> dims = ctx.Input(0)->GetTensorShape()->GetDimSizes();
size_t dimsnum = ctx.Input(0)->GetTensorShape()->GetDims();
size_t dim = dims[dimsnum - 2];
size_t rhs_dim = dims[dimsnum - 1];
if (dimsnum == 3) {
batch_size = dims[dimsnum - 3];
}
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RHS(dim, rhs_dim);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> LHS(dim, dim);
for (size_t k = 0; k < batch_size; k++) {
for (size_t i = 0; i < dim * rhs_dim; i++) {
RHS.data()[i] = rhsptr[k * dim * rhs_dim + i];
}
for (size_t i = 0; i < dim * dim; i++) {
LHS.data()[i] = lhsptr[k * dim * dim + i];
}
if (!upper) {
LHS.template triangularView<Eigen::Lower>().solveInPlace(RHS);
LHS.adjoint().template triangularView<Eigen::Upper>().solveInPlace(RHS);
} else {
LHS.adjoint().template triangularView<Eigen::Lower>().solveInPlace(RHS);
LHS.template triangularView<Eigen::Upper>().solveInPlace(RHS);
}
for (size_t i = 0; i < dim * rhs_dim; i++) {
outptr[k * dim * rhs_dim + i] = RHS.data()[i];
}
}
return KERNEL_STATUS_OK;
}
} // namespace aicpu

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
#define AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
namespace aicpu {
class CholeskySolveCpuKernel : public CpuKernel {
public:
CholeskySolveCpuKernel() = default;
~CholeskySolveCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &upper);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,447 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "combined_non_max_suppression.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <queue>
#include <vector>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kInputNum = 6;
const uint32_t kOutputNum = 4;
const char *kCombinedNonMaxSuppression = "CombinedNonMaxSuppression";
void alloc_zeros(float *arr, int arr_len) {
for (int i = 0; i < arr_len; i++) {
arr[i] = (float)0.0;
}
}
void alloc_zeros(int *arr, int arr_len) {
for (int i = 0; i < arr_len; i++) {
arr[i] = 0;
}
}
} // namespace
namespace aicpu {
// Normalize the diagonal to the input mode of bottom left and top right
void CombinedNonMaxSuppressionCpuKernel::regular_input2buffer(float **boxes_buffer, float *box_src,
const int class_idx) {
/**
* shape of box_src
* box_src[num_boxes*q*4]
* ways to visit box_src[i][class_idx][k] which stored by 1-dimension
* box_src[i][class_idx][k]=box_src[i*q*4+class_idx*4+k]
*/
int box_len1;
int sub_box_len1 = q * 4;
int box_len2 = (class_idx << 2);
for (int i = 0; i < num_boxes; i++) {
box_len1 = i * sub_box_len1 + box_len2;
if (box_src[box_len1] > box_src[box_len1 + 2]) {
boxes_buffer[i][0] = box_src[box_len1 + 2];
boxes_buffer[i][2] = box_src[box_len1 + 0];
} else {
boxes_buffer[i][0] = box_src[box_len1 + 0];
boxes_buffer[i][2] = box_src[box_len1 + 2];
}
if (box_src[box_len1 + 1] > box_src[box_len1 + 3]) {
boxes_buffer[i][1] = box_src[box_len1 + 3];
boxes_buffer[i][3] = box_src[box_len1 + 1];
} else {
boxes_buffer[i][1] = box_src[box_len1 + 1];
boxes_buffer[i][3] = box_src[box_len1 + 3];
}
}
}
// Calculate the area ratio of the intersection of two squares
float CombinedNonMaxSuppressionCpuKernel::IOU(float **boxes_buffer, int i, int j) {
const float *box_a = boxes_buffer[i];
const float *box_b = boxes_buffer[j];
float lx, ly, rx, ry;
float w, h;
float area;
float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
if (area_a <= 0 || area_b <= 0) {
return 0.0;
}
lx = box_a[0] > box_b[0] ? box_a[0] : box_b[0];
ly = box_a[1] > box_b[1] ? box_a[1] : box_b[1];
rx = box_a[2] < box_b[2] ? box_a[2] : box_b[2];
ry = box_a[3] < box_b[3] ? box_a[3] : box_b[3];
w = rx > lx ? (rx - lx) : 0;
h = ry > ly ? (ry - ly) : 0;
area = w * h;
return area / (area_a + area_b - area);
}
/**
* if soft_nms_sigma > 0.0, soft_nms is used, means update by score=score*exp(scale*iou^2)
* if soft_nms_sigma <= 0.0, nms is used, means delete it when iou > iou_threshold
* run non max suppression per bath per class
*/
void CombinedNonMaxSuppressionCpuKernel::non_max_suppression(float **boxes_buffer, float *scores_buffer,
std::vector<int> &selected) {
std::priority_queue<non_max_suppression_local::score_index> pq;
for (int i = 0; i < num_boxes; i++) {
if (scores_buffer[i] > score_threshold) {
pq.push(non_max_suppression_local::score_index(i, scores_buffer[i], 0));
}
}
float scale = static_cast<float>(0.0);
bool is_soft_nms = soft_nms_sigma > static_cast<float>(0.0);
if (is_soft_nms) {
scale = static_cast<float>(-0.5) / soft_nms_sigma;
}
float similarity;
float original_score;
non_max_suppression_local::score_index next_si;
while (((int)selected.size() < size_per_class) && (!pq.empty())) {
next_si = pq.top();
original_score = next_si.score;
pq.pop();
bool should_hard_suppress = false;
for (int j = static_cast<int>(selected.size()) - 1; j >= next_si.suppress_begin_index; j--) {
similarity = IOU(boxes_buffer, next_si.box_index, selected[j]);
if (is_soft_nms) {
next_si.score *=
similarity <= iou_threshold ? std::exp(scale * similarity * similarity) : static_cast<float>(0.0);
}
if (!is_soft_nms && similarity > iou_threshold) {
should_hard_suppress = true;
break;
}
if (next_si.score <= score_threshold) break;
}
next_si.suppress_begin_index = selected.size();
if (!should_hard_suppress) {
if (next_si.score == original_score) {
selected.push_back(next_si.box_index);
continue;
}
if (next_si.score > score_threshold) {
pq.push(next_si);
}
}
}
}
void CombinedNonMaxSuppressionCpuKernel::nms_perclass(
float *boxes, float *scores, std::vector<non_max_suppression_local::result_para> &sub_result_vec, int &result_size) {
int k = 0;
int box_idx;
int boxe_len1;
int sub_box_len1 = q * 4;
int box_len2 = 0;
float **boxes_buffer = new float *[num_boxes]();
float *scores_buffer = new float[num_boxes]();
for (int i = 0; i < num_boxes; i++) {
boxes_buffer[i] = new float[4];
}
/**
* shape of score and boxes
* score[num_boxes*num_class]
* boxes[num_boxes*q*4]
*/
if (q == 1) {
regular_input2buffer(boxes_buffer, boxes, 0);
}
for (int j = 0; j < num_class; j++) {
for (int i = 0; i < num_boxes; i++) {
scores_buffer[i] = scores[i * num_class + j];
}
if (q > 1) {
regular_input2buffer(boxes_buffer, boxes, j);
box_len2 = j * 4;
}
std::vector<int> selected;
non_max_suppression(boxes_buffer, scores_buffer, selected);
for (int i = 0; i < (int)selected.size(); i++) {
box_idx = selected[i];
boxe_len1 = box_idx * sub_box_len1 + box_len2;
sub_result_vec[k++] = {box_idx,
scores_buffer[box_idx],
j,
{boxes[boxe_len1 + 0], boxes[boxe_len1 + 1], boxes[boxe_len1 + 2], boxes[boxe_len1 + 3]}};
}
result_size += selected.size();
}
for (int i = 0; i < num_boxes; i++) {
delete[] boxes_buffer[i];
}
delete[] boxes_buffer;
delete[] scores_buffer;
return;
}
uint32_t CombinedNonMaxSuppressionCpuKernel::nms_perbath(CpuKernelContext &ctx, float *boxes, float *scores,
float *nmsed_boxes, float *nmsed_scores, float *nmsed_class,
int *valid_detection) {
alloc_zeros(nmsed_boxes, num_bath * num_detection * 4);
alloc_zeros(nmsed_scores, num_bath * num_detection);
alloc_zeros(nmsed_class, num_bath * num_detection);
alloc_zeros(valid_detection, num_bath);
const float box_min = 0.0;
const float box_max = 1.0;
/**
* shape of scores and boxes:
* scores[num_bath*num_boxes*num_class]
* boxes[num_bath*num_boxes*q*4]
*/
int score_len2 = num_boxes * num_class;
int boxes_len2 = num_boxes * q * 4;
auto shard_nms = [&](size_t start, size_t end) {
for (int i = start; i < (int)end; i++) {
int per_detections = 0;
int scores_index = 0;
int result_size = 0;
std::vector<non_max_suppression_local::result_para> result_vec(size_per_class * num_class,
{0, 0.0, 0, {0.0, 0.0, 0.0, 0.0}});
nms_perclass(boxes + i * boxes_len2, scores + i * score_len2, result_vec, result_size);
if (!pad_per_class) {
per_detections = std::min(result_size, max_total_size);
} else {
per_detections = std::min(result_size, num_detection);
}
std::sort(result_vec.begin(), result_vec.begin() + result_size, non_max_suppression_local::result_cmp);
scores_index = i * num_detection;
for (int k = 0; k < per_detections; k++) {
if (clip_boxes) {
nmsed_boxes[(scores_index << 2) + 0] = std::max(std::min(result_vec[k].box_coord[0], box_max), box_min);
nmsed_boxes[(scores_index << 2) + 1] = std::max(std::min(result_vec[k].box_coord[1], box_max), box_min);
nmsed_boxes[(scores_index << 2) + 2] = std::max(std::min(result_vec[k].box_coord[2], box_max), box_min);
nmsed_boxes[(scores_index << 2) + 3] = std::max(std::min(result_vec[k].box_coord[3], box_max), box_min);
nmsed_scores[scores_index] = result_vec[k].score;
nmsed_class[scores_index] = (float)result_vec[k].class_idx;
} else {
nmsed_boxes[(scores_index << 2) + 0] = result_vec[k].box_coord[0];
nmsed_boxes[(scores_index << 2) + 1] = result_vec[k].box_coord[1];
nmsed_boxes[(scores_index << 2) + 2] = result_vec[k].box_coord[2];
nmsed_boxes[(scores_index << 2) + 3] = result_vec[k].box_coord[3];
nmsed_scores[scores_index] = result_vec[k].score;
nmsed_class[scores_index] = (float)result_vec[k].class_idx;
}
scores_index++;
}
valid_detection[i] = per_detections;
}
};
uint32_t min_core_num = 1;
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > num_bath) {
max_core_num = num_bath;
}
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_bath, num_bath / max_core_num, shard_nms),
"CombinedNonMaxSuppression Compute failed in nms_perbath stage.");
return KERNEL_STATUS_OK;
}
uint32_t CombinedNonMaxSuppressionCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
"CombinedNonMaxSuppression check input and output number failed.");
KERNEL_HANDLE_ERROR(CombinedNonMaxSuppressionCheck(ctx), "CombinedNonMaxSuppression check params failed.");
CombinedNonMaxSuppressionCompute(ctx);
return KERNEL_STATUS_OK;
}
uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx) {
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Input(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Input(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 2 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Input(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 3 data failed.");
if (ctx.Input(4) != nullptr) {
KERNEL_CHECK_NULLPTR(ctx.Input(4)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 4 data failed.");
}
KERNEL_CHECK_NULLPTR(ctx.Input(5)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 5 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 0 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Output(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 1 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Output(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 2 data failed.");
KERNEL_CHECK_NULLPTR(ctx.Output(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 3 data failed.");
KERNEL_CHECK_FALSE((ctx.Input(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of input0 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(0)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of input1 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(1)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Input(2)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
"The data type of input2 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(2)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Input(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
"The data type of input3 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(3)->GetDataType()).c_str());
if (ctx.Input(4) != NULL) {
KERNEL_CHECK_FALSE((ctx.Input(4)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of input4 [%s] must be [DT_FLOAT].",
DTypeStr(ctx.Input(4)->GetDataType()).c_str());
}
KERNEL_CHECK_FALSE((ctx.Input(5)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of input5 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(5)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Output(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of output0 [%s] must be [DT_FLOAT].",
DTypeStr(ctx.Output(0)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Output(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of output1 [%s] must be [DT_FLOAT].",
DTypeStr(ctx.Output(1)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Output(2)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
"The data type of output2 [%s] must be [DT_FLOAT].",
DTypeStr(ctx.Output(2)->GetDataType()).c_str());
KERNEL_CHECK_FALSE((ctx.Output(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
"The data type of output3 [%s] must be [DT_INT32].",
DTypeStr(ctx.Output(3)->GetDataType()).c_str());
auto input0_shape = ctx.Input(0)->GetTensorShape();
auto input1_shape = ctx.Input(1)->GetTensorShape();
auto input2_shape = ctx.Input(2)->GetTensorShape();
auto input3_shape = ctx.Input(3)->GetTensorShape();
auto input5_shape = ctx.Input(5)->GetTensorShape();
KERNEL_CHECK_FALSE((input0_shape->GetDims() == 4), KERNEL_STATUS_PARAM_INVALID, "The input0's dims [%d] must be 4",
input0_shape->GetDims());
KERNEL_CHECK_FALSE((input1_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The input1's dims [%d] must be 3",
input1_shape->GetDims());
KERNEL_CHECK_FALSE(
(input2_shape->GetDims() == 0 || (input2_shape->GetDims() == 1 && input2_shape->GetDimSize(0) == 1)),
KERNEL_STATUS_PARAM_INVALID, "The input2's dims [%d] must be 0 or 1x1", input2_shape->GetDims());
KERNEL_CHECK_FALSE(
(input3_shape->GetDims() == 0 || (input3_shape->GetDims() == 1 && input3_shape->GetDimSize(0) == 1)),
KERNEL_STATUS_PARAM_INVALID, "The input3's dims [%d] must be 0 or 1x1", input3_shape->GetDims());
if (ctx.Input(4) != nullptr) {
auto input4_shape = ctx.Input(4)->GetTensorShape();
KERNEL_CHECK_FALSE(
(input4_shape->GetDims() == 0 || (input4_shape->GetDims() == 1 && input4_shape->GetDimSize(0) == 1)),
KERNEL_STATUS_PARAM_INVALID, "The input4's dims [%d] must be 0 or 1x1", input4_shape->GetDims());
}
KERNEL_CHECK_FALSE(
(input5_shape->GetDims() == 0 || (input5_shape->GetDims() == 1 && input5_shape->GetDimSize(0) == 1)),
KERNEL_STATUS_PARAM_INVALID, "The input5's dims [%d] must be 0 or 1x1", input5_shape->GetDims());
auto output0_shape = ctx.Output(0)->GetTensorShape();
auto output1_shape = ctx.Output(1)->GetTensorShape();
auto output2_shape = ctx.Output(2)->GetTensorShape();
auto output3_shape = ctx.Output(3)->GetTensorShape();
KERNEL_CHECK_FALSE((output0_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The output0's dims [%d] must be 3",
output0_shape->GetDims());
KERNEL_CHECK_FALSE((output1_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output1's dims [%d] must be 2",
output1_shape->GetDims());
KERNEL_CHECK_FALSE((output2_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output2's dims [%d] must be 2",
output2_shape->GetDims());
KERNEL_CHECK_FALSE((output3_shape->GetDims() == 1), KERNEL_STATUS_PARAM_INVALID, "The output3's dims [%d] must be 1",
output3_shape->GetDims());
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(0) == input1_shape->GetDimSize(0)), KERNEL_STATUS_PARAM_INVALID,
"The input0's 1st dims [%d] need be same with the input1's 1st dims[%d]",
input0_shape->GetDimSize(0), input1_shape->GetDimSize(0));
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(1) == input1_shape->GetDimSize(1)), KERNEL_STATUS_PARAM_INVALID,
"The input0's 2nd dims [%d] need be same with the input1's 2nd dims[%d]",
input0_shape->GetDimSize(1), input1_shape->GetDimSize(1));
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(2) == input1_shape->GetDimSize(2) || input0_shape->GetDimSize(2) == 1),
KERNEL_STATUS_PARAM_INVALID,
"The input0's 3th dims [%d] need be same with the input1's 3th dims [%d] or 1",
input0_shape->GetDimSize(2), output1_shape->GetDimSize(2));
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(3) == 4), KERNEL_STATUS_PARAM_INVALID,
"The input0's 4th dims [%d] need be same with 4", input0_shape->GetDimSize(1));
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == output1_shape->GetDimSize(0) &&
output0_shape->GetDimSize(0) == output2_shape->GetDimSize(0) &&
output0_shape->GetDimSize(0) == output3_shape->GetDimSize(0)),
KERNEL_STATUS_PARAM_INVALID,
"The input0's 1st dims [%d], input1's 1st dims [%d],"
" input2's 1st dims [%d], input3's 1st dims [%d], need be same with each other",
output0_shape->GetDimSize(0), output1_shape->GetDimSize(0), output2_shape->GetDimSize(0),
output3_shape->GetDimSize(0));
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(1) == output1_shape->GetDimSize(1) &&
output0_shape->GetDimSize(1) == output2_shape->GetDimSize(1)),
KERNEL_STATUS_PARAM_INVALID,
"The input0's 2nd dims [%d], input1's 2nd dims [%d], input2's 2nd dims [%d],"
" need be same with each other",
output0_shape->GetDimSize(1), output1_shape->GetDimSize(1), output2_shape->GetDimSize(1));
KERNEL_LOG_INFO(
" CombinedNonMaxSuppressionCpuKernel[%s], input0: size[%llu], "
" input1: size[%llu]",
ctx.GetOpType().c_str(), ctx.Input(0)->GetDataSize(), ctx.Input(1)->GetDataSize());
KERNEL_LOG_INFO(
" output0: size[%llu], output1: size[%llu],"
" output2: size[%llu], output3: size[%llu].",
ctx.Output(0)->GetDataSize(), ctx.Output(1)->GetDataSize(), ctx.Output(2)->GetDataSize(),
ctx.Output(3)->GetDataSize());
return KERNEL_STATUS_OK;
}
uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx) {
float *boxes = reinterpret_cast<float *>(ctx.Input(0)->GetData());
float *scores = reinterpret_cast<float *>(ctx.Input(1)->GetData());
max_output_size_per_class = *(reinterpret_cast<int *>(ctx.Input(2)->GetData()));
max_total_size = *(reinterpret_cast<int *>(ctx.Input(3)->GetData()));
iou_threshold = *(reinterpret_cast<float *>(ctx.Input(4)->GetData()));
score_threshold = *(reinterpret_cast<float *>(ctx.Input(5)->GetData()));
num_bath = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(0));
num_boxes = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(1));
q = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(2));
num_class = (int)(ctx.Input(1)->GetTensorShape()->GetDimSize(2));
pad_per_class = false;
clip_boxes = true;
if (ctx.GetAttr("pad_per_class") != nullptr) {
pad_per_class = (bool)(ctx.GetAttr("pad_per_class")->GetBool());
}
if (ctx.GetAttr("clip_boxes") != nullptr) {
clip_boxes = (bool)(ctx.GetAttr("clip_boxes")->GetBool());
}
float *nmsed_boxes = reinterpret_cast<float *>(ctx.Output(0)->GetData());
float *nmsed_scores = reinterpret_cast<float *>(ctx.Output(1)->GetData());
float *nmsed_class = reinterpret_cast<float *>(ctx.Output(2)->GetData());
int *valid_detection = reinterpret_cast<int *>(ctx.Output(3)->GetData());
auto output0_shape = ctx.Output(0)->GetTensorShape();
auto output1_shape = ctx.Output(1)->GetTensorShape();
auto output2_shape = ctx.Output(2)->GetTensorShape();
auto output3_shape = ctx.Output(3)->GetTensorShape();
size_per_class = max_output_size_per_class < num_boxes ? max_output_size_per_class : num_boxes;
soft_nms_sigma = 0.0;
if (pad_per_class) {
num_detection = std::min(max_total_size, max_output_size_per_class * num_class);
} else {
num_detection = max_total_size;
}
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
"The output0's 1nd dims [%d] must be [%d]", output0_shape->GetDimSize(0), num_bath);
KERNEL_CHECK_FALSE((output1_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
"The output0's 1nd dims [%d] must be [%d]", output1_shape->GetDimSize(0), num_bath);
KERNEL_CHECK_FALSE((output2_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
"The output0's 1nd dims [%d] must be [%d]", output2_shape->GetDimSize(0), num_bath);
KERNEL_CHECK_FALSE((output3_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
"The output0's 1nd dims [%d] must be [%d]", output3_shape->GetDimSize(0), num_bath);
KERNEL_CHECK_FALSE((max_output_size_per_class > 0), KERNEL_STATUS_PARAM_INVALID,
"max_output_size_per_class [%d] must be > 0", max_output_size_per_class);
KERNEL_CHECK_FALSE((max_total_size > 0), KERNEL_STATUS_PARAM_INVALID, "max_total_size [%d] must be > 0",
max_total_size);
KERNEL_CHECK_FALSE((iou_threshold >= 0 && iou_threshold <= 1), KERNEL_STATUS_PARAM_INVALID,
"iou_threshold [%f] must be in [0,1]", iou_threshold);
KERNEL_CHECK_FALSE(((int)output0_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
"The output0's 2nd dims [%d] need be same with %d", output0_shape->GetDimSize(1), num_detection);
KERNEL_CHECK_FALSE(((int)output1_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
"The output1's 2nd dims [%d] need be same with %d", output1_shape->GetDimSize(1), num_detection);
KERNEL_CHECK_FALSE(((int)output2_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
"The output2's 2nd dims [%d] need be same with %d", output2_shape->GetDimSize(1), num_detection);
nms_perbath(ctx, boxes, scores, nmsed_boxes, nmsed_scores, nmsed_class, valid_detection);
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCombinedNonMaxSuppression, CombinedNonMaxSuppressionCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_
#define AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_
#include "cpu_ops_kernel.h"
#include "utils/bcast.h"
namespace non_max_suppression_local {
struct score_index {
int box_index;
float score;
int suppress_begin_index;
score_index() {}
score_index(int bi, float s, int sbi) : box_index(bi), score(s), suppress_begin_index(sbi) {}
bool operator<(const score_index &b) const {
return (score < b.score) || ((score == b.score) && (box_index > b.box_index));
}
};
struct result_para {
int box_index;
float score;
int class_idx;
float box_coord[4];
};
bool result_cmp(result_para &a, result_para &b) { return a.score > b.score; }
} // namespace non_max_suppression_local
namespace aicpu {
class CombinedNonMaxSuppressionCpuKernel : public CpuKernel {
public:
CombinedNonMaxSuppressionCpuKernel() = default;
~CombinedNonMaxSuppressionCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx);
uint32_t CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx);
uint32_t nms_perbath(CpuKernelContext &, float *, float *, float *, float *, float *, int *);
void regular_input2buffer(float **, float *, const int);
float IOU(float **, int, int);
void non_max_suppression(float **, float *, std::vector<int> &);
void nms_perclass(float *, float *, std::vector<non_max_suppression_local::result_para> &, int &);
int num_bath;
int num_boxes;
int q;
int num_class;
// per batch size
int num_detection;
int max_total_size;
// The length of each type of selection defined by the user
int max_output_size_per_class;
// Calculation num_detection length
int size_per_class;
// When lower than a score_threshold, delete the relevant box
float score_threshold;
// When it is higher than the threshold value, according to the soft_nms_sigma determines deletion or decay
float iou_threshold;
float soft_nms_sigma;
bool pad_per_class;
bool clip_boxes;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "complex.h"
#include "Eigen/Eigen"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 2;
const char *kComplex = "Complex";
constexpr int64_t kFloatMaxNums = 8 * 128 * 1024;
constexpr int64_t kDoubleMaxNums = 16 * 128 * 1024;
#define Complex_COMPUTE_CASE(IN_DTYPE, IN_TYPE, OUT_DTYPE, CTX) \
case (IN_DTYPE): { \
switch (OUT_DTYPE) { \
case (DT_COMPLEX64): { \
uint32_t result = ComplexCompute<float, std::complex<float>>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Complex kernel compute failed."); \
return result; \
} \
break; \
} \
case (DT_COMPLEX128): { \
uint32_t result = ComplexCompute<double, std::complex<double>>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Complex kernel compute failed."); \
return result; \
} \
break; \
} \
default: \
KERNEL_LOG_ERROR("Complex kernel output data type [%s] not support.", DTypeStr(OUT_DTYPE).c_str()); \
return KERNEL_STATUS_PARAM_INVALID; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t ComplexCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kComplex);
DataType input_type = ctx.Input(0)->GetDataType();
switch (input_type) {
Complex_COMPUTE_CASE(DT_FLOAT, float, DT_COMPLEX64, ctx)
Complex_COMPUTE_CASE(DT_DOUBLE, double, DT_COMPLEX128, ctx) default
: KERNEL_LOG_ERROR("Complex kernel input data type [%s] not support.", DTypeStr(input_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T, typename t>
uint32_t ComplexCpuKernel::ComplexCompute(CpuKernelContext &ctx) {
auto input0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto input1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto output = reinterpret_cast<t *>(ctx.Output(0)->GetData());
auto data_type = ctx.Input(0)->GetDataType();
int64_t data_num = ctx.Output(0)->NumElements();
int64_t data_size = data_num * sizeof(T);
if ((data_type == DT_FLOAT && data_size <= kFloatMaxNums) ||
(data_type == DT_DOUBLE && data_size <= kDoubleMaxNums)) {
for (int64_t index = 0; index < data_num; ++index) {
*(output + index) = t(*(input0 + index), *(input1 + index));
}
} else {
uint32_t min_core_num = 1;
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto shard_complex = [&](size_t start, size_t end) {
for (size_t index = start; index < end; ++index) {
*(output + index) = t(*(input0 + index), *(input1 + index));
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_complex),
"complex Compute failed");
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kComplex, ComplexCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,40 @@
/**
* Copyright (C) 2020-2021. Huawei Technologies Co., Ltd. All rights reserved.
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*
* @brief
*
* @version 1.0
*
*/
#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_H_
#define AICPU_KERNELS_NORMALIZED_COMPLEX_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class ComplexCpuKernel : public CpuKernel {
public:
ComplexCpuKernel() = default;
~ComplexCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t ComplexCheck(CpuKernelContext &ctx);
template <typename T, typename t>
uint32_t ComplexCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Jilin University
* Copyright 2020 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "complex_abs.h"
#include <unsupported/Eigen/CXX11/Tensor>
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "kernel_log.h"
#include "status.h"
#include "utils/kernel_util.h"
namespace {
const std::uint32_t kComplexAbsInputNum{1u};
const std::uint32_t kComplexAbsOutputNum{1u};
const char *kComplexAbs{"ComplexAbs"};
const std::int64_t kComplexAbsParallelNum{32 * 1024};
} // namespace
namespace aicpu {
namespace detail {
template <typename T>
inline const typename T::value_type ScalarComplexAbs(const T &x) {
return std::abs(x);
}
inline std::uint32_t ParallelForComplexAbs(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size,
const std::function<void(std::int64_t, std::int64_t)> &work) {
if (total > kComplexAbsParallelNum)
return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work);
else
work(0, total);
return KERNEL_STATUS_OK;
}
template <typename T>
inline std::uint32_t ComputeComplexAbsKernel(const CpuKernelContext &ctx) {
T *input0{static_cast<T *>(ctx.Input(0)->GetData())};
typename T::value_type *output{static_cast<typename T::value_type *>(ctx.Output(0)->GetData())};
std::int64_t total{ctx.Input(0)->NumElements()};
std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)};
std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)};
return ParallelForComplexAbs(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) {
std::transform(input0 + begin, input0 + end, output + begin, ScalarComplexAbs<T>);
});
}
template <typename T>
inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) {
std::uint32_t result{ComputeComplexAbsKernel<T>(ctx)};
if (result != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("ComplexAbs compute failed.");
}
return result;
}
inline std::uint32_t ExtraCheckComplexAbs(const CpuKernelContext &ctx) {
if (ctx.Input(0)->GetDataType() == ctx.Output(0)->GetDataType()) {
KERNEL_LOG_ERROR(
"The data type of the input [%s] should not be the same as the output "
"[%s].",
DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize() * 2) {
KERNEL_LOG_ERROR(
"The data size of the input [%llu] need be as twice as the output "
"[%llu].",
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
inline std::uint32_t CheckComplexAbs(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) {
return NormalCheck(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? KERNEL_STATUS_PARAM_INVALID
: ExtraCheckComplexAbs(ctx);
}
inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) {
DataType input_type{ctx.Input(0)->GetDataType()};
switch (input_type) {
case DT_COMPLEX64:
return ComputeComplexAbs<std::complex<std::float_t>>(ctx);
case DT_COMPLEX128:
return ComputeComplexAbs<std::complex<std::double_t>>(ctx);
default:
KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
} // namespace detail
std::uint32_t ComplexAbsCpuKernel::Compute(CpuKernelContext &ctx) {
return detail::CheckComplexAbs(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? KERNEL_STATUS_PARAM_INVALID
: detail::ComputeComplexAbs(ctx);
}
REGISTER_CPU_KERNEL(kComplexAbs, ComplexAbsCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Jilin University
* Copyright 2020 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_
#define AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class ComplexAbsCpuKernel final : public CpuKernel {
virtual std::uint32_t Compute(CpuKernelContext &ctx) override final;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,303 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "concat.h"
#include "utils/kernel_util.h"
using namespace std;
namespace {
const char *const Concat = "Concat";
}
namespace aicpu {
uint32_t ConcatCpuKernel::CheckAndInitParams(CpuKernelContext &ctx) {
if (ctx.GetAttr("N") == nullptr) {
n_ = 1;
} else {
AttrValue *n_ptr = ctx.GetAttr("N");
n_ = n_ptr->GetInt();
}
// "x" is a list of at least 2 "tensor" objects of the same type
KERNEL_CHECK_FALSE((n_ >= 2), KERNEL_STATUS_PARAM_INVALID, "Attr N must >= 2, but got attr N[%lld]", n_);
uint32_t input_num = ctx.GetInputsSize();
// input_num is n_(concat tensor num) + 1(concat_dim)
KERNEL_CHECK_FALSE((static_cast<int64_t>(input_num) - 1 == n_), KERNEL_STATUS_PARAM_INVALID,
"Input num must equal attr N[%lld + 1],"
"but got input num[%u]",
n_, input_num);
Tensor *concat_dim_ptr = ctx.Input(0);
KERNEL_CHECK_NULLPTR(concat_dim_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim failed.");
auto concat_dim_shape_ptr = concat_dim_ptr->GetTensorShape();
KERNEL_CHECK_NULLPTR(concat_dim_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim shape failed.");
int32_t concat_dim_dims = concat_dim_shape_ptr->GetDims();
KERNEL_CHECK_FALSE((concat_dim_dims == 0) || ((concat_dim_dims == 1) && (concat_dim_shape_ptr->NumElements() == 1)),
KERNEL_STATUS_PARAM_INVALID, "Input concat_dim should be a scalar integer, but got rank[%d].",
concat_dim_dims);
int32_t concat_dim = 0;
DataType concat_dim_data_type = concat_dim_ptr->GetDataType();
KERNEL_CHECK_FALSE((concat_dim_data_type == DT_INT32 || concat_dim_data_type == DT_INT64),
KERNEL_STATUS_PARAM_INVALID,
"Input concat_dim data type must DT_INT32 or DT_INT64,"
"but got data type[%d].",
DTypeStr(concat_dim_data_type).c_str());
auto concat_dim_data_ptr = concat_dim_ptr->GetData();
KERNEL_CHECK_NULLPTR(concat_dim_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim data failed.");
if (concat_dim_data_type == DT_INT32) {
concat_dim = static_cast<int64_t>(*reinterpret_cast<int32_t *>(concat_dim_data_ptr));
} else {
concat_dim = *reinterpret_cast<int64_t *>(concat_dim_data_ptr);
}
Tensor *input0_ptr = ctx.Input(1);
auto input0_type_ptr = input0_ptr->GetDataType();
for (int64_t i = 1; i < n_; i++) {
Tensor *inputi_ptr = ctx.Input(i);
auto inputi_shape_ptr = inputi_ptr->GetTensorShape();
auto inputi_type_ptr = inputi_ptr->GetDataType();
KERNEL_CHECK_NULLPTR(inputi_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input xi failed.");
KERNEL_CHECK_NULLPTR(inputi_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input xi shape failed.");
KERNEL_CHECK_FALSE((input0_type_ptr == inputi_type_ptr), KERNEL_STATUS_PARAM_INVALID,
"Input tensor should have same type"
"but got %d and %d.",
DTypeStr(input0_type_ptr).c_str(), DTypeStr(inputi_type_ptr).c_str());
}
auto input0_shape_ptr = input0_ptr->GetTensorShape();
input_dims_ = input0_shape_ptr->GetDims();
data_type_ = input0_ptr->GetDataType();
axis_ = concat_dim < 0 ? concat_dim + input_dims_ : concat_dim;
KERNEL_CHECK_FALSE((0 <= axis_ && axis_ < input_dims_), KERNEL_STATUS_PARAM_INVALID,
"Input concat_dim need in the "
"range[%d, %d), but got %lld.",
-input_dims_, input_dims_, concat_dim);
inputs_flat_dim0_ = 1;
for (uint32_t d = 0; d < axis_; ++d) {
inputs_flat_dim0_ *= input0_shape_ptr->GetDimSize(d);
}
return KERNEL_STATUS_OK;
}
uint32_t ConcatCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_CHECK_FALSE((CheckAndInitParams(ctx) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
"CheckAndInitParams failed.");
switch (data_type_) {
case DT_FLOAT16:
return DoCompute<Eigen::half>(ctx);
case DT_FLOAT:
return DoCompute<float>(ctx);
case DT_DOUBLE:
return DoCompute<double>(ctx);
case DT_INT8:
return DoCompute<int8_t>(ctx);
case DT_INT16:
return DoCompute<int16_t>(ctx);
case DT_INT32:
return DoCompute<int32_t>(ctx);
case DT_INT64:
return DoCompute<int64_t>(ctx);
case DT_UINT8:
return DoCompute<uint8_t>(ctx);
case DT_UINT16:
return DoCompute<uint16_t>(ctx);
case DT_UINT32:
return DoCompute<uint32_t>(ctx);
case DT_UINT64:
return DoCompute<uint64_t>(ctx);
case DT_COMPLEX64:
return DoCompute<complex<float>>(ctx);
case DT_COMPLEX128:
return DoCompute<complex<double>>(ctx);
default:
KERNEL_LOG_ERROR("unsupported datatype[%d]", DTypeStr(data_type_).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
template <typename T>
uint32_t ConcatCpuKernel::PrepareInput(CpuKernelContext &ctx,
std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs) {
inputs.reserve(n_);
output_concat_dim_ = 0;
auto input0_shape_ptr = ctx.Input(1)->GetTensorShape();
for (uint32_t i = 0; i < n_; ++i) {
Tensor *input_i_ptr = ctx.Input(i + 1);
int64_t input_i_num = input_i_ptr->NumElements();
if (input_i_num == 0) {
continue;
}
auto input_i_shape_ptr = input_i_ptr->GetTensorShape();
int32_t input_i_dims = input_i_shape_ptr->GetDims();
KERNEL_CHECK_FALSE((input_i_dims == input_dims_), KERNEL_STATUS_PARAM_INVALID,
"Ranks of inputs should match: shape[0]=%d vs. shape[%u]=%d", input_dims_, i, input_i_dims);
for (int32_t j = 0; j < input_dims_; ++j) {
int64_t dim_ij = input_i_shape_ptr->GetDimSize(j);
if (j == axis_) {
output_concat_dim_ += input_i_dims > 0 ? dim_ij : 1;
continue;
}
int64_t dim_0j = input0_shape_ptr->GetDimSize(j);
KERNEL_CHECK_FALSE((dim_0j == dim_ij), KERNEL_STATUS_PARAM_INVALID,
"Dimensions of inputs should match: shape[0][%d]=%lld vs."
"shape[%u][%d]=%lld",
j, dim_0j, i, j, dim_ij);
}
int64_t inputs_flat_dim1 = input_i_num / inputs_flat_dim0_;
auto input_i_data_ptr = input_i_ptr->GetData();
KERNEL_CHECK_NULLPTR(input_i_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x%u data failed.", i);
auto input_i = std::make_shared<typename TTypes<T>::ConstMatrix>(reinterpret_cast<T *>(input_i_data_ptr),
inputs_flat_dim0_, inputs_flat_dim1);
KERNEL_CHECK_NULLPTR(input_i, KERNEL_STATUS_PARAM_INVALID, "Create input x%u failed!", i);
inputs.emplace_back(std::move(input_i));
}
if (input_dims_ == 0) {
output_concat_dim_ = n_;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ConcatCpuKernel::PrepareOutput(CpuKernelContext &ctx, std::shared_ptr<typename TTypes<T>::Matrix> &output) {
Tensor *output_ptr = ctx.Output(0);
KERNEL_CHECK_NULLPTR(output_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output failed.");
auto output_data_ptr = output_ptr->GetData();
KERNEL_CHECK_NULLPTR(output_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.");
int64_t output_num = output_ptr->NumElements();
int64_t output_dim1 = output_num / inputs_flat_dim0_;
output = std::make_shared<typename TTypes<T>::Matrix>(reinterpret_cast<T *>(output_data_ptr), inputs_flat_dim0_,
output_dim1);
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Create output matrix failed.");
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ConcatCpuKernel::DoCompute(CpuKernelContext &ctx) {
std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> inputs;
KERNEL_CHECK_FALSE((PrepareInput<T>(ctx, inputs) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
"PrepareInput failed.");
std::shared_ptr<typename TTypes<T>::Matrix> output = nullptr;
KERNEL_CHECK_FALSE((PrepareOutput<T>(ctx, output) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
"PrepareOutput failed.");
if (inputs.size() > 0) {
return ConcatCompute<T>(ctx, inputs, output);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ConcatCpuKernel::ConcatCompute(CpuKernelContext &ctx,
const std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs,
std::shared_ptr<typename TTypes<T>::Matrix> &output) {
size_t num_inputs = inputs.size();
std::vector<ptrdiff_t> sizes;
sizes.reserve(num_inputs);
int64_t row_size = 0;
for (const auto &input : inputs) {
sizes.push_back(input->dimension(1));
row_size += sizes.back();
}
uint32_t ret = KERNEL_STATUS_OK;
auto work = [&row_size, &sizes, &inputs, &output, &num_inputs, &ret](int64_t start, int64_t end) {
if (row_size == 0) {
ret = KERNEL_STATUS_PARAM_INVALID;
return;
}
int64_t skipped_rows = start / row_size;
T *out = output->data() + skipped_rows * row_size;
T *out_start = output->data() + start;
T *out_end = output->data() + end;
// Handle partial row at start
if (out < out_start) {
for (size_t j = 0; j < num_inputs; ++j) {
ptrdiff_t size = sizes[j];
ptrdiff_t offset = out_start - out;
if (size <= offset) {
out += size;
continue;
}
const T *inp = &(*inputs[j])(skipped_rows, 0);
if (offset > 0) {
out += offset;
inp += offset;
size -= offset;
}
size = std::min(size, out_end - out);
KERNEL_CHECK_FALSE_EXEC((size > 0), break)
size_t copy_size = size * sizeof(T);
error_t ret = memcpy_s(out, copy_size, inp, copy_size);
if (ret != EOK) {
KERNEL_LOG_ERROR("Memcpy failed.");
ret = KERNEL_STATUS_INNER_ERROR;
return;
}
out += size;
}
++skipped_rows;
}
if (out < out_start || out > out_end) {
KERNEL_LOG_ERROR("Out[%llx] not in range[%llx, %llx)", out, out_start, out_end);
ret = KERNEL_STATUS_INNER_ERROR;
return;
}
// Copy remaining data.
std::vector<const T *> inp;
inp.reserve(num_inputs);
for (const auto &input : inputs) {
inp.push_back(&(*input)(skipped_rows, 0));
}
const int64_t dim0 = output->dimension(0);
for (int64_t i = skipped_rows; i < dim0; ++i) {
for (int64_t j = 0; j < static_cast<int64_t>(num_inputs); ++j) {
ptrdiff_t size = std::min(sizes[j], out_end - out);
size_t copy_size = size * sizeof(T);
auto ret = memcpy_s(out, copy_size, inp[j], copy_size);
if (ret != EOK) {
KERNEL_LOG_ERROR("Memcpy size[%zu] from inp[%llx] to out[%llx] failed.", copy_size, inp[j], out);
ret = KERNEL_STATUS_INNER_ERROR;
return;
}
out += size;
inp[j] += size;
KERNEL_CHECK_FALSE_EXEC((out != out_end), return );
}
}
};
const int64_t kParallelDataNumSameShapeBig = 255 * 1024;
int64_t data_num = output->size();
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
if (data_num >= kParallelDataNumSameShapeBig) {
max_core_num = std::min(max_core_num, 6U);
} else {
max_core_num = std::min(max_core_num, 1U);
}
if (max_core_num == 0) {
KERNEL_LOG_ERROR("max_core_num could not be 0.");
}
CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, work);
KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "ConcatCpuKernel failed.");
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(Concat, ConcatCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,76 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CONCAT_H_
#define AICPU_KERNELS_NORMALIZED_CONCAT_H_
#include <memory>
#include <vector>
#include "cpu_ops_kernel.h"
#include "cpu_kernel_utils.h"
#include "kernel_log.h"
#include "securec.h"
#include "status.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace aicpu {
const uint32_t NumIndices = 2;
template <typename T>
struct TTypes {
// Rank-2 tensor (matrix) of scalar type T.
using Matrix = Eigen::TensorMap<Eigen::Tensor<T, NumIndices, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>;
using ConstMatrix =
Eigen::TensorMap<Eigen::Tensor<const T, NumIndices, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>;
};
class ConcatCpuKernel : public CpuKernel {
public:
ConcatCpuKernel()
: data_type_(DT_DOUBLE), input_dims_(0), n_(0), output_concat_dim_(0), axis_(0), inputs_flat_dim0_(0) {}
~ConcatCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t CheckAndInitParams(CpuKernelContext &ctx);
template <typename T>
uint32_t PrepareInput(CpuKernelContext &ctx, std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs);
template <typename T>
uint32_t PrepareOutput(CpuKernelContext &ctx, std::shared_ptr<typename TTypes<T>::Matrix> &output);
template <typename T>
uint32_t DoCompute(CpuKernelContext &ctx);
template <typename T>
uint32_t ConcatCompute(CpuKernelContext &ctx,
const std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs,
std::shared_ptr<typename TTypes<T>::Matrix> &output);
private:
DataType data_type_;
int32_t input_dims_;
int64_t n_;
int64_t output_concat_dim_;
int64_t axis_;
int64_t inputs_flat_dim0_;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,123 @@
/**
* Copyright 2021 Jilin University
* Copyright 2020 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cos.h"
#include <unsupported/Eigen/CXX11/Tensor>
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "kernel_log.h"
#include "status.h"
#include "utils/kernel_util.h"
namespace {
const std::uint32_t kCosInputNum{1u};
const std::uint32_t kCosOutputNum{1u};
const char *kCos{"Cos"};
const std::int64_t kCosParallelNum{64 * 1024};
} // namespace
namespace aicpu {
namespace detail {
template <typename T>
inline T ScalarCos(T x) {
return std::cos(x);
}
template <>
inline Eigen::half ScalarCos(Eigen::half x) {
const Eigen::half val{static_cast<Eigen::half>(std::cos(static_cast<std::float_t>(x)))};
return Eigen::half_impl::isnan(val) ? Eigen::half{0.0f} : val;
}
inline std::uint32_t ParallelForCos(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size,
const std::function<void(std::int64_t, std::int64_t)> &work) {
if (total > kCosParallelNum)
return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work);
else
work(0, total);
return KERNEL_STATUS_OK;
}
template <typename T>
inline std::uint32_t ComputeCosKernel(const CpuKernelContext &ctx) {
T *input0{static_cast<T *>(ctx.Input(0)->GetData())};
T *output{static_cast<T *>(ctx.Output(0)->GetData())};
std::int64_t total{ctx.Input(0)->NumElements()};
std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)};
std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)};
return ParallelForCos(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) {
std::transform(input0 + begin, input0 + end, output + begin, ScalarCos<T>);
});
}
template <typename T>
inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) {
std::uint32_t result{ComputeCosKernel<T>(ctx)};
if (result != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("Cos compute failed.");
}
return result;
}
inline std::uint32_t ExtraCheckCos(const CpuKernelContext &ctx) {
if (ctx.Input(0)->GetDataType() != ctx.Output(0)->GetDataType()) {
KERNEL_LOG_ERROR("The data type of the input [%s] need be the same as the output [%s].",
DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize()) {
KERNEL_LOG_ERROR(
"The data size of the input [%llu] need be the same as the output "
"[%llu].",
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
inline std::uint32_t CheckCos(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) {
return NormalCheck(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : ExtraCheckCos(ctx);
}
inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) {
DataType input_type{ctx.Input(0)->GetDataType()};
switch (input_type) {
case DT_FLOAT16:
return ComputeCos<Eigen::half>(ctx);
case DT_FLOAT:
return ComputeCos<std::float_t>(ctx);
case DT_DOUBLE:
return ComputeCos<std::double_t>(ctx);
case DT_COMPLEX64:
return ComputeCos<std::complex<std::float_t>>(ctx);
case DT_COMPLEX128:
return ComputeCos<std::complex<std::double_t>>(ctx);
default:
KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
} // namespace detail
std::uint32_t CosCpuKernel::Compute(CpuKernelContext &ctx) {
return detail::CheckCos(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : detail::ComputeCos(ctx);
}
REGISTER_CPU_KERNEL(kCos, CosCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Jilin University
* Copyright 2020 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_COS_H_
#define AICPU_KERNELS_NORMALIZED_COS_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class CosCpuKernel final : public CpuKernel {
virtual std::uint32_t Compute(CpuKernelContext &ctx) override final;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,272 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "count_nonzero.h"
#include <algorithm>
#include <complex>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kCountNonZeroInputNum = 1;
const uint32_t kCountNonZeroOutputNum = 1;
const int64_t kParallelNum = 2 * 1024;
const char *kCountNonZero = "CountNonZero";
// The following code is used to handle the general case.
// Params to use in ParallelIterator construction.
std::vector<int64_t> cnz_dims;
std::vector<int64_t> cnz_transposed_shape;
int64_t cnz_stride;
// Class def of ParallelIterator.
class ParallelIterator {
public:
ParallelIterator(std::vector<int64_t> transposed_shape, std::vector<int64_t> dims,
const std::vector<int64_t> &input_shape);
~ParallelIterator() = default;
void Next();
void Set(int64_t pos);
inline int64_t Get() const { return _pos; };
private:
int64_t _dimension{0};
std::vector<int64_t> _coord;
std::vector<int64_t> _shape;
std::vector<int64_t> _strides;
std::vector<int64_t> _back_strides;
std::vector<int64_t> _dims;
int64_t _pos{0};
};
ParallelIterator::ParallelIterator(std::vector<int64_t> transposed_shape, std::vector<int64_t> dims,
const std::vector<int64_t> &input_shape)
: _dimension(transposed_shape.size()),
_coord(transposed_shape.size(), 0),
_shape(transposed_shape),
_strides(transposed_shape.size(), 1),
_back_strides(transposed_shape.size(), 1),
_dims(dims),
_pos(0) {
std::vector<int64_t> strides(_dimension, 1);
for (int64_t i = _dimension - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * input_shape[i + 1];
}
for (int64_t i = _dimension - 1; i >= 0; --i) {
_strides[i] = strides[_dims[i]];
_back_strides[i] = (_shape[i] - 1) * _strides[i];
}
}
void ParallelIterator::Set(int64_t pos) {
for (int64_t i = _dimension - 1; i >= 0 && pos != 0; --i) {
_coord[i] = pos % _shape[i];
_pos += _coord[i] * _strides[i];
pos /= _shape[i];
}
}
void ParallelIterator::Next() {
for (int64_t i = _dimension - 1; i >= 0; --i) {
if (_coord[i] + 1 == _shape[i]) {
_coord[i] = 0;
_pos -= _back_strides[i];
} else {
_coord[i]++;
_pos += _strides[i];
break;
}
}
}
// The two structs is used for tag dispatch in IsNonZero.
template <class T>
struct is_complex_t : std::false_type {};
template <class T>
struct is_complex_t<std::complex<T>> : std::true_type {};
template <class T>
int64_t IsNonZero(T val, std::true_type) {
return val.real() != 0 || val.imag() != 0 ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
}
template <class T>
int64_t IsNonZero(T val, std::false_type) {
return val != static_cast<T>(0) ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
}
} // namespace
namespace aicpu {
template <class T>
uint32_t CountNonZeroComputeImpl(CpuKernelContext &ctx) {
Tensor *x_tensor = ctx.Input(kFirstInputIndex);
Tensor *y_tensor = ctx.Output(kFirstOutputIndex);
const T *x_ptr = reinterpret_cast<const T *>(x_tensor->GetData());
int64_t *y_ptr = reinterpret_cast<int64_t *>(y_tensor->GetData());
int64_t data_num = y_tensor->NumElements();
int64_t input_data_num = x_tensor->NumElements();
std::vector<int64_t> input_shape = x_tensor->GetTensorShape()->GetDimSizes();
// For scalar_reduction, start=0, end=input_data_num, cannot be parallelized.
auto count_nonzero_scalar_shard = [&](int64_t start, int64_t end) {
y_ptr[0] = 0;
for (int64_t i = start; i < end; ++i) {
y_ptr[0] += IsNonZero<T>(x_ptr[i], is_complex_t<T>{});
}
};
// For general case. Can be parallelized but performance is not good.
auto count_nonzero_shard = [&](int64_t start, int64_t end) {
ParallelIterator iter(cnz_transposed_shape, cnz_dims, input_shape);
iter.Set(start * cnz_stride);
for (int64_t i = start; i < end; ++i) {
int64_t reduce_initial = static_cast<int64_t>(0);
for (int64_t j = 0; j < cnz_stride; ++j) {
reduce_initial += IsNonZero<T>(x_ptr[iter.Get()], is_complex_t<T>{});
iter.Next();
}
y_ptr[i] = reduce_initial;
}
};
if (data_num == 1) {
count_nonzero_scalar_shard(0, input_data_num);
} else if (data_num > kParallelNum) {
CpuKernelUtils::ParallelFor(ctx, data_num, 1, count_nonzero_shard);
} else {
count_nonzero_shard(0, data_num);
}
return KERNEL_STATUS_OK;
}
uint32_t CountNonZeroDimsCheckAndParse(CpuKernelContext &ctx) {
std::vector<int64_t> input_shape = ctx.Input(kFirstInputIndex)->GetTensorShape()->GetDimSizes();
int64_t input_rank = input_shape.size();
std::vector<int64_t> dims{};
auto dims_attr = ctx.GetAttr("dims");
if (dims_attr != nullptr) {
dims = dims_attr->GetListInt();
}
if (dims.size() == 0) {
for (int64_t i = 0; i < input_rank; ++i) {
dims.push_back(i);
}
}
// Check dims in [-x_rank, x_rank)
for (auto &dim : dims) {
if (dim < 0) {
dim += input_rank;
}
KERNEL_CHECK_FALSE(dim < input_rank && dim >= 0, KERNEL_STATUS_PARAM_INVALID,
"[CountNonZero] dims must be in [-x_rank, x_rank).");
}
std::sort(dims.begin(), dims.end());
dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
int64_t stride_ = static_cast<int64_t>(1);
std::vector<int64_t> transposed_shape(input_rank);
// axes is the transpose of indices.
// For example, if input_rank = 5, dims = [1, 3],
// then axes =[0, 2, 4] + [1, 3].
// Initial value if axes is [?, ?, ?, ?, ?]
std::vector<int64_t> axes_(input_rank);
int64_t j = static_cast<int64_t>(0), k = static_cast<int64_t>(0);
// Put dim indices to keep to front of axes and calculate stride.
// After this operation, axes becomes [0, 2, 4] + [?, ?],
// and stride becomes 1 * 2 * 4
for (int64_t i = 0; i < input_rank; i++) {
if (j == static_cast<int64_t>(dims.size()) || i != dims[j]) {
axes_[k] = i;
++k;
} else {
stride_ *= input_shape[i];
++j;
}
}
// Put dim indices to reduce to back of axes.
// After this operation, axes becomes [0, 2, 4] + [1, 3]
for (auto &dim : dims) {
axes_[k] = dim;
++k;
}
// Calculate transposed_shape using axes.
// For example, if input_shape = (3, 4, 5, 6, 7), axes = [0, 2, 4, 1, 3],
// then transposed_shape = (3, 5, 7) + (4, 6)
std::vector<int64_t> transposed_shape_(input_rank);
for (int64_t i = 0; i < input_rank; ++i) {
transposed_shape_[i] = input_shape[axes_[i]];
}
// Assign values.
cnz_stride = stride_, cnz_transposed_shape = transposed_shape_, cnz_dims = axes_;
return KERNEL_STATUS_OK;
}
uint32_t CountNonZeroCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCountNonZeroInputNum, kCountNonZeroOutputNum),
"[%s] check input and output failed.", kCountNonZero);
KERNEL_HANDLE_ERROR(CountNonZeroDimsCheckAndParse(ctx), "[%s] check & parse dims failed.", kCountNonZero);
auto y_data_type = ctx.Output(kFirstOutputIndex)->GetDataType();
KERNEL_CHECK_FALSE(y_data_type == DT_INT64, KERNEL_STATUS_PARAM_INVALID,
"[CountNonZero] Data type of output not supported, which is [%s].", DTypeStr(y_data_type).c_str());
auto x_data_type = ctx.Input(kFirstInputIndex)->GetDataType();
switch (x_data_type) {
case DT_INT8:
return CountNonZeroComputeImpl<int8_t>(ctx);
break;
case DT_INT16:
return CountNonZeroComputeImpl<int16_t>(ctx);
break;
case DT_INT32:
return CountNonZeroComputeImpl<int32_t>(ctx);
break;
case DT_INT64:
return CountNonZeroComputeImpl<int64_t>(ctx);
break;
case DT_UINT8:
return CountNonZeroComputeImpl<uint8_t>(ctx);
break;
case DT_UINT16:
return CountNonZeroComputeImpl<uint16_t>(ctx);
break;
case DT_UINT32:
return CountNonZeroComputeImpl<uint32_t>(ctx);
break;
case DT_UINT64:
return CountNonZeroComputeImpl<uint64_t>(ctx);
break;
case DT_FLOAT16:
return CountNonZeroComputeImpl<Eigen::half>(ctx);
break;
case DT_FLOAT:
return CountNonZeroComputeImpl<float>(ctx);
break;
case DT_DOUBLE:
return CountNonZeroComputeImpl<double>(ctx);
break;
case DT_COMPLEX64:
return CountNonZeroComputeImpl<std::complex<float>>(ctx);
break;
case DT_COMPLEX128:
return CountNonZeroComputeImpl<std::complex<double>>(ctx);
break;
default:
KERNEL_LOG_ERROR("[CountNonZero] kernel data type [%s] not support.", DTypeStr(x_data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCountNonZero, CountNonZeroCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,31 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H
#define AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H
#include "cpu_ops_kernel.h"
namespace aicpu {
class CountNonZeroCpuKernel : public CpuKernel {
public:
~CountNonZeroCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,124 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "csr_sparse_matrix_to_dense.h"
#include <securec.h>
#include <complex>
#include <numeric>
#include <string>
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "kernel_log.h"
#include "status.h"
#include "utils/allocator_utils.h"
#include "utils/kernel_util.h"
using namespace std;
namespace {
const uint32_t kInputNum = 5;
const uint32_t kOutputNum = 1;
const char *CSRSparseMatrixToDense = "CSRSparseMatrixToDense";
#define SWITCH_CASE(_IDX_T, _VALUE_T, VALUE_T, FLAG, CTX) \
case _VALUE_T: \
switch (_IDX_T) { \
case DT_INT32: \
(FLAG) = DoCompute<int32_t, VALUE_T>(CTX); \
break; \
case DT_INT64: \
(FLAG) = DoCompute<int64_t, VALUE_T>(CTX); \
break; \
default: \
KERNEL_LOG_ERROR("CSRSparseMatrixToDense index type [%s] not support.", DTypeStr(_IDX_T).c_str()); \
return KERNEL_STATUS_PARAM_INVALID; \
} \
break;
} // namespace
namespace aicpu {
uint32_t CSRSparseMatrixToDenseCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToDense normal check failed.");
DataType indice_type = ctx.Input(0)->GetDataType();
DataType value_type = ctx.Input(4)->GetDataType();
uint32_t status = 0;
switch (value_type) {
SWITCH_CASE(indice_type, DT_FLOAT, float_t, status, ctx)
SWITCH_CASE(indice_type, DT_DOUBLE, double_t, status, ctx)
SWITCH_CASE(indice_type, DT_COMPLEX64, complex<float_t>, status, ctx)
SWITCH_CASE(indice_type, DT_COMPLEX128, complex<double_t>, status, ctx)
default:
KERNEL_LOG_ERROR("CSRSparseMatrixToDense values type [%s] not support.", DTypeStr(value_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
KERNEL_HANDLE_ERROR(status, "CSRSparseMatrixToDense compute failed!");
return KERNEL_STATUS_OK;
}
template <typename indiceT, typename valueT>
uint32_t CSRSparseMatrixToDenseCpuKernel::DoCompute(CpuKernelContext &ctx) {
indiceT batch_size = ctx.Input(1)->NumElements() - 1;
auto rank = ctx.Input(0)->NumElements();
int shift = (rank == 2) ? 0 : 1;
indiceT num_rows = *(static_cast<indiceT *>(ctx.Input(0)->GetData()) + shift);
indiceT num_cols = *(static_cast<indiceT *>(ctx.Input(0)->GetData()) + shift + 1);
indiceT *batch_ptrs = static_cast<indiceT *>(ctx.Input(1)->GetData());
indiceT *row_ptrs = static_cast<indiceT *>(ctx.Input(2)->GetData());
indiceT *col_ind = static_cast<indiceT *>(ctx.Input(3)->GetData());
valueT *values = static_cast<valueT *>(ctx.Input(4)->GetData());
auto output = ctx.Output(0);
auto output_shape = output->GetTensorShape();
if (rank == 2) {
output_shape->SetDimSizes({num_rows, num_cols});
} else {
output_shape->SetDimSizes({batch_size, num_rows, num_cols});
}
output->SetTensorShape(output_shape.get());
valueT *y_data = static_cast<valueT *>(ctx.Output(0)->GetData());
// use multi-thread
uint32_t min_core = 1;
uint32_t max_core = std::max(min_core, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
max_core = std::min(max_core, (uint32_t)batch_size);
if (max_core == 0) {
KERNEL_LOG_ERROR("Max core num cannot be zero.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto shard = [&](int64_t start, int64_t end) {
for (int64_t batch_idx = start; batch_idx < end; batch_idx++) {
const int64_t dense_offset = batch_idx * num_rows * num_cols;
for (int64_t i = 0; i < num_rows * num_cols; ++i) {
y_data[dense_offset + i] = 0;
}
const int64_t csr_batch_offset = batch_ptrs[batch_idx];
for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
const int64_t row_offset = batch_idx * (num_rows + 1) + row_idx;
const int64_t col_begin = row_ptrs[row_offset];
const int64_t col_end = row_ptrs[row_offset + 1];
for (int64_t i = col_begin; i < col_end; ++i) {
const int64_t col_idx = col_ind[csr_batch_offset + i];
y_data[dense_offset + (row_idx * num_cols) + col_idx] = values[csr_batch_offset + i];
}
}
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core, shard),
"CSRSparseMatrixToDense Compute failed.");
return KERNEL_STATUS_OK;
}
// register the opetaor
REGISTER_CPU_KERNEL(CSRSparseMatrixToDense, CSRSparseMatrixToDenseCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,35 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
#define AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
#include "Eigen/Core"
#include "cpu_ops_kernel.h"
namespace aicpu {
class CSRSparseMatrixToDenseCpuKernel : public CpuKernel {
public:
~CSRSparseMatrixToDenseCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename indiceT, typename valueT>
uint32_t DoCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,229 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "csr_sparse_matrix_to_sparse_tensor.h"
#include <complex>
#include <iostream>
#include "cpu_kernel_utils.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kInputNum = 5;
const uint32_t kOutputNum = 3;
const char *CSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
// when input data size is more than kParallelDataNum, use Parallel func
const int64_t kParallelDataNum = 4;
const int64_t kParallelDataNumMid = 32;
const int DIM2 = 2;
const int DIM3 = 3;
} // namespace
namespace aicpu {
uint32_t CSRSparseMatrixToSparseTensorCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToSparseTensor normal check failed.");
Tensor *x_dense_shape = ctx.Input(0);
Tensor *x_batch_pointers = ctx.Input(1);
Tensor *x_row_pointers = ctx.Input(2);
Tensor *x_col_indices = ctx.Input(3);
Tensor *x_values = ctx.Input(4);
const int rank = x_dense_shape->NumElements();
if (rank != DIM2 && rank != DIM3) {
KERNEL_LOG_ERROR("CSR SparseMatrix must have rank 2 or 3.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto x_row_pointers_shape = x_row_pointers->GetTensorShape();
auto x_col_indices_shape = x_col_indices->GetTensorShape();
auto x_values_shape = x_values->GetTensorShape();
if (x_col_indices_shape->NumElements() != x_values_shape->NumElements()) {
KERNEL_LOG_ERROR("Tensor x_col_indices&x_values's ranks mismatch.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto x_dense_shape_data_type = x_dense_shape->GetDataType();
auto x_batch_pointers_data_type = x_batch_pointers->GetDataType();
auto x_row_pointers_data_type = x_row_pointers->GetDataType();
auto x_col_indices_data_type = x_col_indices->GetDataType();
if (x_col_indices_data_type != DT_INT32 && x_col_indices_data_type != DT_INT64) {
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type [%s] not support.",
DTypeStr(x_col_indices_data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (x_dense_shape_data_type != x_col_indices_data_type || x_batch_pointers_data_type != x_col_indices_data_type ||
x_row_pointers_data_type != x_col_indices_data_type) {
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type mismatch.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto x_values_data_type = x_values->GetDataType();
uint32_t status;
switch (x_col_indices_data_type) {
case DT_INT32:
switch (x_values_data_type) {
case DT_FLOAT:
status = ComputeKernel<int32_t, float>(ctx);
break;
case DT_DOUBLE:
status = ComputeKernel<int32_t, double>(ctx);
break;
case DT_COMPLEX64:
status = ComputeKernel<int32_t, std::complex<float> >(ctx);
break;
case DT_COMPLEX128:
status = ComputeKernel<int32_t, std::complex<double> >(ctx);
break;
default:
KERNEL_LOG_ERROR(
"CSRSparseMatrixToSparseTensor kernel data type [%s] not "
"support.",
DTypeStr(x_values_data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
break;
case DT_INT64:
switch (x_values_data_type) {
case DT_FLOAT:
status = ComputeKernel<int64_t, float>(ctx);
break;
case DT_DOUBLE:
status = ComputeKernel<int64_t, double>(ctx);
break;
case DT_COMPLEX64:
status = ComputeKernel<int64_t, std::complex<float> >(ctx);
break;
case DT_COMPLEX128:
status = ComputeKernel<int64_t, std::complex<double> >(ctx);
break;
default:
KERNEL_LOG_ERROR(
"CSRSparseMatrixToSparseTensor kernel data type [%s] not "
"support.",
DTypeStr(x_values_data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
break;
default:
KERNEL_LOG_ERROR("data type of indices is not int32 or int64");
return KERNEL_STATUS_PARAM_INVALID;
}
if (status != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel compute failed.");
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTensorCpuKernel);
template <typename indicesT, typename dataT>
uint32_t CSRSparseMatrixToSparseTensorCpuKernel::ComputeKernel(CpuKernelContext &ctx) {
auto x_dense_shape = ctx.Input(0);
auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
auto dense_shape_ptr = static_cast<indicesT *>(ctx.Output(2)->GetData());
auto values_ptr = static_cast<dataT *>(ctx.Output(1)->GetData());
auto x_values = ctx.Input(4);
auto x_values_ptr = static_cast<dataT *>(x_values->GetData());
// Copy the SparseTensor's dense_shape and values from the CSRSparseMatrix.
for (int64_t i = 0; i < x_dense_shape->GetTensorShape()->NumElements(); i++) {
dense_shape_ptr[i] = x_dense_shape_ptr[i];
}
for (int64_t i = 0; i < x_values->GetTensorShape()->NumElements(); i++) {
values_ptr[i] = x_values_ptr[i];
}
const uint32_t batch_size = ctx.Input(1)->NumElements() - 1;
if (batch_size >= kParallelDataNum) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (batch_size <= kParallelDataNumMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > batch_size) {
max_core_num = batch_size;
}
auto sharder = [&](int64_t batch_begin, int64_t batch_end) {
SpecialCompute<indicesT>(batch_begin, batch_end, ctx);
};
if (max_core_num == 0) {
KERNEL_LOG_ERROR("max_core_num could not be 0.");
}
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core_num, sharder),
"CSRSparseMatrixToSparseTensor Compute failed.");
} else {
SpecialCompute<indicesT>(0, batch_size, ctx);
}
return KERNEL_STATUS_OK;
}
template <typename indicesT>
void CSRSparseMatrixToSparseTensorCpuKernel::SpecialCompute(int64_t batch_begin, int64_t batch_end,
CpuKernelContext &ctx) {
auto x_dense_shape = ctx.Input(0);
const int rank = x_dense_shape->NumElements();
auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
const int64_t num_rows = x_dense_shape_ptr[(rank == DIM2) ? 0 : 1];
auto x_batch_pointers_ptr = static_cast<indicesT *>(ctx.Input(1)->GetData());
auto x_row_pointers_ptr = static_cast<indicesT *>(ctx.Input(2)->GetData());
auto x_col_indices_ptr = static_cast<indicesT *>(ctx.Input(3)->GetData());
for (int64_t batch_idx = batch_begin; batch_idx < batch_end; ++batch_idx) {
const int64_t batch_offset = x_batch_pointers_ptr[batch_idx];
for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
int64_t row_offset = batch_idx * (num_rows + 1) + row_idx;
// The column indices of the current row lie in the range:
// [x_row_pointers_ptr[row_offset], x_row_pointer_ptr[row_offset + 1]]
const int64_t col_begin = x_row_pointers_ptr[row_offset];
const int64_t col_end = x_row_pointers_ptr[row_offset + 1];
for (int64_t i = col_begin; i < col_end; ++i) {
const int64_t col_idx = x_col_indices_ptr[batch_offset + i];
const int64_t indices_offset = rank * (batch_offset + i);
IndicesCompute<indicesT>(ctx, indices_offset, batch_idx, row_idx, col_idx);
}
}
}
}
template <typename indicesT>
void CSRSparseMatrixToSparseTensorCpuKernel::IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset,
const int64_t batch_idx, const int64_t row_idx,
const int64_t col_idx) {
const int rank = ctx.Input(0)->NumElements();
auto indices_ptr = static_cast<indicesT *>(ctx.Output(0)->GetData());
if (rank == DIM2) {
indices_ptr[indices_offset] = row_idx;
indices_ptr[++indices_offset] = col_idx;
} else { // rank == 3
indices_ptr[indices_offset] = batch_idx;
indices_ptr[++indices_offset] = row_idx;
indices_ptr[++indices_offset] = col_idx;
}
}
} // namespace aicpu

View File

@ -0,0 +1,41 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
#define AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class CSRSparseMatrixToSparseTensorCpuKernel : public CpuKernel {
public:
~CSRSparseMatrixToSparseTensorCpuKernel() = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename indicesT, typename dataT>
uint32_t ComputeKernel(CpuKernelContext &ctx);
template <typename indicesT>
void SpecialCompute(int64_t batch_begin, int64_t batch_end, CpuKernelContext &ctx);
template <typename indicesT>
void IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset, const int64_t batch_idx, const int64_t row_idx,
const int64_t col_idx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,202 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cumprod.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kCumprodInputNum = 2;
const uint32_t kCumprodOutputNum = 1;
const int64_t paralled_data_size = 512 * 1024;
const char *kCumprod = "Cumprod";
#define CUMPROD_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = CumprodCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Cumprod kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t CumprodCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCumprodInputNum, kCumprodOutputNum), "[%s] check input and output failed.",
kCumprod);
// parse params
KERNEL_HANDLE_ERROR(CumprodCheck(ctx), "[%s] check params failed.", kCumprod);
auto data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
CUMPROD_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
CUMPROD_COMPUTE_CASE(DT_FLOAT, float, ctx)
CUMPROD_COMPUTE_CASE(DT_DOUBLE, double, ctx)
CUMPROD_COMPUTE_CASE(DT_INT8, int8_t, ctx)
CUMPROD_COMPUTE_CASE(DT_INT16, int16_t, ctx)
CUMPROD_COMPUTE_CASE(DT_INT32, int32_t, ctx)
CUMPROD_COMPUTE_CASE(DT_INT64, int64_t, ctx)
CUMPROD_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
CUMPROD_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
CUMPROD_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
CUMPROD_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
CUMPROD_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
CUMPROD_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
default:
KERNEL_LOG_ERROR("Cumprod kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t CumprodCpuKernel::CumprodCheck(CpuKernelContext &ctx) {
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.")
KERNEL_CHECK_NULLPTR(ctx.GetAttr("exclusive"), KERNEL_STATUS_PARAM_INVALID, "get exclusive failed.");
KERNEL_CHECK_NULLPTR(ctx.GetAttr("reverse"), KERNEL_STATUS_PARAM_INVALID, "get reverse failed.");
KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_INT32 || ctx.Input(1)->GetDataType() == DT_INT64),
KERNEL_STATUS_PARAM_INVALID,
"Data type of axis is not support, axis data type is [%u], only support int32 or int64.",
ctx.Input(1)->GetDataType());
KERNEL_CHECK_FALSE(ctx.Input(1)->NumElements() == 1, KERNEL_STATUS_PARAM_INVALID, "axis is out of shape")
auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
int32_t axis = *axis_data;
KERNEL_CHECK_FALSE((axis < ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
"axis is larger than input dims - 1");
KERNEL_CHECK_FALSE((axis >= -ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
"axis is lower than -input dims");
std::vector<int64_t> shape_input = ctx.Input(0)->GetTensorShape()->GetDimSizes();
std::vector<int64_t> shape_output = ctx.Output(0)->GetTensorShape()->GetDimSizes();
KERNEL_CHECK_FALSE((shape_input.size() != 0), KERNEL_STATUS_PARAM_INVALID,
"Input must be at least rank 1, got [%zu].", shape_input.size())
KERNEL_CHECK_FALSE((shape_input.size() == shape_output.size()), KERNEL_STATUS_PARAM_INVALID,
"The output shape size should be same as the output shape size")
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CumprodCpuKernel::CumprodCompute(CpuKernelContext &ctx) {
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
int32_t axis = *axis_data;
bool exclusive = ctx.GetAttr("exclusive")->GetBool();
bool reverse = ctx.GetAttr("reverse")->GetBool();
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
auto shape = ctx.Input(0)->GetTensorShape();
const int64_t rank = shape->GetDims();
if (axis < 0) {
axis += shape->GetDims();
}
size_t inner = 1;
size_t outer = 1;
size_t depth = 1;
for (int32_t i = 0; i < rank; ++i) {
if (i < axis) {
inner *= shape->GetDimSize(i);
} else if (i > axis) {
outer *= shape->GetDimSize(i);
} else {
depth = shape->GetDimSize(i);
}
}
int64_t data_num = ctx.Input(0)->NumElements();
int64_t data_size = data_num * sizeof(T);
if (data_size <= paralled_data_size) {
for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
size_t outer_index_adj;
if (reverse)
outer_index_adj = (outer - 1) - outer_index;
else
outer_index_adj = outer_index;
for (size_t inner_index = 0; inner_index < inner; inner_index++) {
auto multiplier = static_cast<T>(1);
size_t inner_index_adj;
if (reverse)
inner_index_adj = (inner - 1) - inner_index;
else
inner_index_adj = inner_index;
for (size_t depth_index = 0; depth_index < depth; depth_index++) {
size_t depth_index_adj;
if (reverse)
depth_index_adj = (depth - 1) - depth_index;
else
depth_index_adj = depth_index;
size_t index = outer_index_adj;
index += inner_index_adj * depth * outer;
index += depth_index_adj * outer;
if (exclusive) {
output_data[index] = multiplier;
multiplier *= input_data[index];
} else {
multiplier *= input_data[index];
output_data[index] = multiplier;
}
}
}
}
} else {
auto shard_cumprod = [&](size_t start, size_t ene) {
for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
size_t outer_index_adj;
if (reverse) {
outer_index_adj = (outer - 1) - outer_index;
} else {
outer_index_adj = outer_index;
}
for (size_t inner_index = 0; inner_index < inner; inner_index++) {
auto multiplier = static_cast<T>(1);
size_t inner_index_adj;
if (reverse) {
inner_index_adj = (inner - 1) - inner_index;
} else {
inner_index_adj = inner_index;
}
for (size_t depth_index = 0; depth_index < depth; depth_index++) {
size_t depth_index_adj;
if (reverse) {
depth_index_adj = (depth - 1) - depth_index;
} else {
depth_index_adj = depth_index;
}
size_t index = outer_index_adj;
index += inner_index_adj * depth * outer;
index += depth_index_adj * outer;
if (exclusive) {
output_data[index] = multiplier;
multiplier *= input_data[index];
} else {
multiplier *= input_data[index];
output_data[index] = multiplier;
}
}
}
}
};
uint32_t min_core_num = 1;
size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > outer) {
max_core_num = outer;
}
if (max_core_num == 0) {
KERNEL_LOG_ERROR("max_core_num could not be 0");
}
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, outer, outer / max_core_num, shard_cumprod),
"Cumprod Compute failed.");
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCumprod, CumprodCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CUMPROD_H_
#define AICPU_KERNELS_NORMALIZED_CUMPROD_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class CumprodCpuKernel : public CpuKernel {
public:
CumprodCpuKernel() = default;
~CumprodCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t CumprodCheck(CpuKernelContext &ctx);
template <typename T>
uint32_t CumprodCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -208,4 +208,4 @@ uint32_t CumulativeLogsumexpCpuKernel::CumulativeLogsumexpCompute(CpuKernelConte
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(KCumulativeLogsumexp, CumulativeLogsumexpCpuKernel);
} // namespace aicpu
} // namespace aicpu

View File

@ -35,4 +35,4 @@ class CumulativeLogsumexpCpuKernel : public CpuKernel {
uint32_t CumulativeLogsumexpCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif
#endif

View File

@ -0,0 +1,166 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "shuffle_channel.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <vector>
namespace {
const uint32_t kInputNum = 1;
const uint32_t kOutputNum = 1;
const char *kShuffleChannel = "ShuffleChannel";
const int64_t minDimSize = 3;
#define SHUFFLE_CHANNEL_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = ShuffleChannelCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Shuffle Channel kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t ShuffleChannelCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "ShuffleChannel check input and output number failed.");
KERNEL_HANDLE_ERROR(ShuffleChannelParamCheck(ctx), "ShuffleChannel check params failed.");
auto data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT, float, ctx)
default:
KERNEL_LOG_ERROR("Shuffle kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t ShuffleChannelCpuKernel::ShuffleChannelParamCheck(CpuKernelContext &ctx) {
// the non null of input_0, input_1, output has been verified in NormalCheck
Tensor *input = ctx.Input(0);
auto group = 1;
if (ctx.GetAttr("group")) {
group = ctx.GetAttr("group")->GetInt();
}
int64_t c = input->GetTensorShape()->GetDimSize(1);
KERNEL_CHECK_FALSE(input->GetTensorShape()->GetDims() >= minDimSize, KERNEL_STATUS_PARAM_INVALID,
"ShuffleChannel expect input with > 2 dims.")
KERNEL_CHECK_FALSE(group > 0, KERNEL_STATUS_PARAM_INVALID, "Number of groups to divide channels in must be positive.")
KERNEL_CHECK_FALSE((c % group) == 0, KERNEL_STATUS_PARAM_INVALID, "Number of channels must be divisible by groups")
KERNEL_LOG_DEBUG(
"ShuffleChannelCpuKernel[%s], input: size[%llu];"
"output: size[%llu].",
ctx.GetOpType().c_str(), input->GetDataSize(), ctx.Output(0)->GetDataSize());
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ShuffleChannelCpuKernel::ShuffleChannelCompute(CpuKernelContext &ctx) {
Tensor *input = ctx.Input(0);
Tensor *output = ctx.Output(0);
auto group = 1;
if (ctx.GetAttr("group")) {
group = ctx.GetAttr("group")->GetInt();
}
auto shape = input->GetTensorShape();
int64_t b = shape->GetDimSize(0);
int64_t c = shape->GetDimSize(1);
int64_t dims = shape->GetDims();
if (group == 0) {
KERNEL_LOG_ERROR("group divided can not be zero");
return KERNEL_STATUS_PARAM_INVALID;
}
int64_t oc = c / group;
auto out = reinterpret_cast<T *>(output->GetData());
auto in = reinterpret_cast<T *>(input->GetData());
int64_t loc_out = 0;
int64_t loc_in = 0;
int64_t area = 1;
for (int64_t i = 2; i < dims; i++) {
area = area * shape->GetDimSize(i);
}
/*
view the shape to n g c/g h*w,and transpose dim 1 and dim 2
*/
std::vector<int64_t> temp_shape;
temp_shape.push_back(b);
temp_shape.push_back(group);
temp_shape.push_back(oc);
temp_shape.push_back(area);
std::vector<int64_t> temp_loc = {0, 0, 0, 0};
bool can_plus = false;
int64_t dim = 0;
while (true) {
for (dim = 0; dim <= minDimSize; dim++) {
if (dim == minDimSize) {
loc_in = loc_in + temp_loc[dim] * 1;
loc_out = loc_out + temp_loc[dim] * 1;
} else if (dim == (minDimSize - 1)) {
loc_in = loc_in + temp_loc[dim] * area;
loc_out = loc_out + temp_loc[1] * area;
} else if (dim == 1) {
loc_in = loc_in + temp_loc[dim] * oc * area;
loc_out = loc_out + temp_loc[minDimSize - 1] * group * area;
} else if (dim == 0) {
loc_in = loc_in + temp_loc[dim] * group * oc * area;
loc_out = loc_out + temp_loc[dim] * group * oc * area;
}
}
*(out + loc_out) = *(in + loc_in);
loc_in = 0;
loc_out = 0;
can_plus = false;
for (dim = 0; dim <= minDimSize; dim++) {
if (temp_loc[dim] < (temp_shape[dim] - 1)) {
can_plus = true;
break;
}
}
if (!can_plus) {
break;
}
for (dim = minDimSize; dim >= 0; dim--) {
if (temp_loc[dim] == (temp_shape[dim] - 1)) {
temp_loc[dim] = 0;
} else {
temp_loc[dim] = temp_loc[dim] + 1;
break;
}
}
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kShuffleChannel, ShuffleChannelCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,39 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
#define AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
#include "cpu_ops_kernel.h"
#include "utils/bcast.h"
namespace aicpu {
class ShuffleChannelCpuKernel : public CpuKernel {
public:
ShuffleChannelCpuKernel() = default;
~ShuffleChannelCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t ShuffleChannelParamCheck(CpuKernelContext &ctx);
template <typename T>
uint32_t ShuffleChannelCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -52,9 +52,24 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
kFSEDecodeOpName};
static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2DV1OpName,
mindspore::kAdaptiveAvgPool2DGradV1OpName,
mindspore::kBucketizeOpName,
mindspore::kCacheSwapTableOpName,
mindspore::kCauchyOpName,
mindspore::kChannelShuffleOpName,
mindspore::kCholeskyGradOpName,
mindspore::kCholeskyInverseOpName,
mindspore::kCholeskySolveOpName,
mindspore::kCol2imOpName,
mindspore::kCombinedNonMaxSuppressionOpName,
mindspore::kComplexOpName,
mindspore::kComplexAbsOpName,
mindspore::kConcatOpName,
mindspore::kCosOpName,
mindspore::kCountNonZeroOpName,
mindspore::kCumulativeLogsumexpOpName,
mindspore::kCumprodOpName,
mindspore::kCSRSparseMatrixToDenseOpName,
mindspore::kCSRSparseMatrixToSparseTensorOpName,
mindspore::kDataFormatVecPermuteOpName,
mindspore::kFillOpName,
mindspore::kLogMatrixDeterminantOpName,

View File

@ -184,6 +184,20 @@ from .nan_to_num import _nan_to_num_aicpu
from .qr import _qr_aicpu
from .col2im import _col2im_aicpu
from .matrix_solve_ls import _matrix_solve_ls_aicpu
from .cauchy import _cauchy_aicpu
from .bucketize import _bucketize_aicpu
from .channel_shuffle import _channel_shuffle_aicpu
from .choleskygrad import _choleskygrad_aicpu
from .cholesky_inverse import _cholesky_inverse_aicpu
from .cholesky_solve import _cholesky_solve_aicpu
from .combined_non_max_suppression import _combined_non_max_suppression_aicpu
from .complex import _complex_aicpu
from .complex_abs import _complex_abs_aicpu
from .concat import _concat_aicpu
from .cos import _cos_aicpu
from .count_nonzero import _count_nonzero_aicpu
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
from .cumprod import _cumprod_aicpu
from .exp import _exp_aicpu
from .matrix_triangular_solve import _matrix_triangular_solve_aicpu
from .maximum_grad_grad import _maximum_grad_grad_aicpu