forked from mindspore-Ecosystem/mindspore
add aicpu_ops
This commit is contained in:
parent
88833bada9
commit
f00862beb2
|
@ -86,6 +86,7 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "redundantAssignment"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constArgument"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "unknownMacro"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"
|
||||
|
|
|
@ -132,3 +132,6 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/operators"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comma"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/indentation_namespace"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/line_length"
|
||||
|
||||
|
|
|
@ -282,6 +282,9 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
|
|||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d.cc:aicpu::MaxUnpool2DCpuKernel::MaxUnpool2DCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_solve_ls.cc:aicpu::MatrixSolveLsCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/col2im.cc:aicpu::Col2imCpuKernel::Col2imParamCheck
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc:aicpu::CSRSparseMatrixToSparseTensorCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc:aicpu::CumprodCpuKernel::CumprodCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc:aicpu::CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd_update.cc:aicpu::ScatterNdUpdateCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/ragged_tensor_to_sparse.cc:aicpu::RaggedTensorToSparseCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_3d_grad.cc:aicpu::MaxUnpool3DGradCpuKernel::MaxUnpool3DGradCompute
|
||||
|
|
|
@ -149,15 +149,21 @@ constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
|
|||
constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
|
||||
constexpr auto kBpropCutOpName = "bprop_cut";
|
||||
constexpr auto kBroadcastToOpName = "BroadcastTo";
|
||||
constexpr auto kBucketizeOpName = "Bucketize";
|
||||
constexpr auto kIFMROpName = "IFMR";
|
||||
constexpr auto kBroadcastToDOpName = "BroadcastToD";
|
||||
constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
|
||||
constexpr auto kCallOpName = "call";
|
||||
constexpr auto kCauchyOpName = "Cauchy";
|
||||
constexpr auto kCastOpName = "Cast";
|
||||
constexpr auto kCentralizationOpName = "Centralization";
|
||||
constexpr auto kCeLUOpName = "CeLU";
|
||||
constexpr auto kCeluV2OpName = "CeluV2";
|
||||
constexpr auto kChannelShuffleOpName = "ChannelShuffle";
|
||||
constexpr auto kCheckNumericsOpName = "CheckNumerics";
|
||||
constexpr auto kCholeskyGradOpName = "CholeskyGrad";
|
||||
constexpr auto kCholeskyInverseOpName = "CholeskyInverse";
|
||||
constexpr auto kCholeskySolveOpName = "CholeskySolve";
|
||||
constexpr auto kClearZeroOpName = "ClearZero";
|
||||
constexpr auto kClipBoxesOpName = "kClipBoxes";
|
||||
constexpr auto kClipBoxesDOpName = "kClipBoxesD";
|
||||
|
@ -165,8 +171,11 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
|
|||
constexpr auto kClipByValueOpName = "ClipByValue";
|
||||
constexpr auto kCoalesceOpName = "Coalesce";
|
||||
constexpr auto kCol2imOpName = "Col2im";
|
||||
constexpr auto kCombinedNonMaxSuppressionOpName = "CombinedNonMaxSuppression";
|
||||
constexpr auto kCombineMomentumOpName = "CombineMomentum";
|
||||
constexpr auto kCombineMomentumWeightOpName = "CombineMomentumWeight";
|
||||
constexpr auto kComplexOpName = "Complex";
|
||||
constexpr auto kComplexAbsOpName = "ComplexAbs";
|
||||
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
|
||||
constexpr auto kConcatOpName = "Concat";
|
||||
constexpr auto kConcatDOpName = "ConcatD";
|
||||
|
@ -193,6 +202,8 @@ constexpr auto kCropAndResizeOpName = "CropAndResize";
|
|||
constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
|
||||
constexpr auto kConvBN1OpName = "ConvBN1";
|
||||
constexpr auto kCOO2CSROpName = "COO2CSR";
|
||||
constexpr auto kCosOpName = "Cos";
|
||||
constexpr auto kCountNonZeroOpName = "CountNonZero";
|
||||
constexpr auto kCSR2COOOpName = "CSR2COO";
|
||||
constexpr auto kCSRDivOpName = "CSRDiv";
|
||||
constexpr auto kCSRGatherOpName = "CSRGather";
|
||||
|
@ -200,6 +211,7 @@ constexpr auto kCSRMMOpName = "CSRMM";
|
|||
constexpr auto kCSRMulOpName = "CSRMul";
|
||||
constexpr auto kCSRMVOpName = "CSRMV";
|
||||
constexpr auto kCSRReduceSumOpName = "CSRReduceSum";
|
||||
constexpr auto kCSRSparseMatrixToDenseOpName = "CSRSparseMatrixToDense";
|
||||
constexpr auto kCSRSparseMatrixToSparseTensorOpName = "CSRSparseMatrixToSparseTensor";
|
||||
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
|
||||
constexpr auto kCumprodOpName = "Cumprod";
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "bucketize.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 1;
|
||||
const char *kBucketize = "Bucketize";
|
||||
const int64_t kParallelDataNumSameShape = 64 * 1024;
|
||||
const int64_t kParallelDataNumSameShapeMid = 35 * 1024;
|
||||
|
||||
#define BUCKETIZE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = BucketizeCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Bucketize kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
int64_t get_tensor_length(aicpu::Tensor *t) {
|
||||
std::vector<int64_t> dim_sizes = t->GetTensorShape()->GetDimSizes();
|
||||
int64_t length = 1;
|
||||
for (auto x : dim_sizes) {
|
||||
length *= x;
|
||||
}
|
||||
return length;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t BucketizeCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// normal check
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Bucketize check input and output number failed.");
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
|
||||
KERNEL_CHECK_NULLPTR(ctx.GetAttr("boundaries"), KERNEL_STATUS_PARAM_INVALID, "Get boundaries failed")
|
||||
|
||||
// check input datatype
|
||||
Tensor *input = ctx.Input(kFirstInputIndex);
|
||||
DataType dt_input = input->GetDataType();
|
||||
KERNEL_CHECK_FALSE((dt_input == DT_FLOAT || dt_input == DT_INT32 || dt_input == DT_INT64 || dt_input == DT_DOUBLE),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input data type must DT_FLOAT or DT_INT32 or DT_INT64 or DT_DOUBLE,"
|
||||
"but got data type[%s].",
|
||||
DTypeStr(dt_input).c_str());
|
||||
|
||||
// check output datatype
|
||||
Tensor *output = ctx.Output(kFirstOutputIndex);
|
||||
DataType dt_output = output->GetDataType();
|
||||
KERNEL_CHECK_FALSE((dt_output == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Output data type must DT_INT32, but got data type[%s].", DTypeStr(dt_output).c_str());
|
||||
|
||||
auto input_sizes = input->GetTensorShape()->GetDimSizes();
|
||||
auto output_sizes = output->GetTensorShape()->GetDimSizes();
|
||||
KERNEL_CHECK_FALSE((input_sizes == output_sizes), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The tensor shape of input [%s] need be same with "
|
||||
"output [%s].",
|
||||
VectorToString(input_sizes).c_str(), VectorToString(output_sizes).c_str());
|
||||
|
||||
switch (data_type) {
|
||||
BUCKETIZE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
BUCKETIZE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
BUCKETIZE_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
BUCKETIZE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Bucketize kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t BucketizeCpuKernel::BucketizeCompute(CpuKernelContext &ctx) {
|
||||
const int64_t data_num = get_tensor_length(ctx.Input(0));
|
||||
auto boundaries = ctx.GetAttr("boundaries");
|
||||
std::vector<float> boundaries_data = boundaries->GetListFloat();
|
||||
std::sort(boundaries_data.begin(), boundaries_data.end());
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<int32_t *>(ctx.Output(0)->GetData());
|
||||
|
||||
if (data_num >= kParallelDataNumSameShape) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
|
||||
if (data_num <= kParallelDataNumSameShapeMid) {
|
||||
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
|
||||
}
|
||||
|
||||
if (max_core_num > data_num) {
|
||||
max_core_num = data_num;
|
||||
}
|
||||
auto sharder_bucketize = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
|
||||
output_data[i] = first_bigger_it - boundaries_data.begin();
|
||||
}
|
||||
};
|
||||
CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_bucketize);
|
||||
} else {
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
|
||||
output_data[i] = first_bigger_it - boundaries_data.begin();
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kBucketize, BucketizeCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class BucketizeCpuKernel : public CpuKernel {
|
||||
public:
|
||||
BucketizeCpuKernel() = default;
|
||||
~BucketizeCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
template <typename T>
|
||||
static uint32_t BucketizeCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cauchy.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include <vector>
|
||||
#include <random>
|
||||
|
||||
namespace {
|
||||
const char *kCauchy = "Cauchy";
|
||||
const int64_t kParallelDataNums = 64 * 1024;
|
||||
const uint32_t knum = 2;
|
||||
} // namespace
|
||||
|
||||
#define CAUCHY_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = CauchyCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Cauchy kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
namespace aicpu {
|
||||
uint32_t CauchyCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
Tensor *output_tensor = ctx.Output(0);
|
||||
auto output_dtype = output_tensor->GetDataType();
|
||||
KERNEL_CHECK_NULLPTR(output_tensor, KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_tensor is nullptr.");
|
||||
|
||||
switch (output_dtype) {
|
||||
CAUCHY_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
CAUCHY_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Cauchy kernel data type [%s] not support.", DTypeStr(output_dtype).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
uint32_t CauchyCpuKernel::CauchyCompute(CpuKernelContext &ctx) {
|
||||
AttrValue *median = ctx.GetAttr("median");
|
||||
if (median != nullptr) {
|
||||
median_ = median->GetFloat();
|
||||
}
|
||||
|
||||
AttrValue *sigma = ctx.GetAttr("sigma");
|
||||
if (sigma != nullptr) {
|
||||
sigma_ = sigma->GetFloat();
|
||||
}
|
||||
|
||||
Tensor *y_tensor = ctx.Output(0);
|
||||
|
||||
AttrValue *output_size_attr = ctx.GetAttr("size");
|
||||
KERNEL_CHECK_NULLPTR(output_size_attr, KERNEL_STATUS_PARAM_INVALID, "CauchyCompute get size failed.");
|
||||
std::vector<int64_t> output_size = ctx.GetAttr("size")->GetListInt();
|
||||
if (output_size.empty()) {
|
||||
KERNEL_LOG_ERROR("CauchyCompute get size is empty.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto y_shape = y_tensor->GetTensorShape();
|
||||
y_shape->SetDimSizes(output_size);
|
||||
|
||||
int64_t y_num = y_tensor->NumElements();
|
||||
KERNEL_CHECK_NULLPTR(y_tensor->GetData(), KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_data is nullptr.");
|
||||
T *y_data = static_cast<T *>(y_tensor->GetData());
|
||||
std::default_random_engine generator(std::random_device{}());
|
||||
|
||||
std::cauchy_distribution<float> cauchy_d(median_, sigma_);
|
||||
uint32_t max_core_num = 1;
|
||||
if (y_num >= kParallelDataNums) {
|
||||
max_core_num = std::max(max_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - knum);
|
||||
if (max_core_num > y_num) {
|
||||
max_core_num = y_num;
|
||||
}
|
||||
}
|
||||
|
||||
auto Cauchy_d = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
float data = cauchy_d(generator);
|
||||
y_data[i] = static_cast<T>(data);
|
||||
}
|
||||
};
|
||||
|
||||
uint32_t ret = CpuKernelUtils::ParallelFor(ctx, y_num, y_num / max_core_num, Cauchy_d);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("CpuKernelUtils::ParallelFor failed.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
KERNEL_LOG_INFO("CauchyCpuKernel::ComputeCauchy end.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kCauchy, CauchyCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
|
||||
#define EIGEN_USE_THREADS
|
||||
#define EIGEN_USE_SIMPLE_THREAD_POOL
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace aicpu {
|
||||
class CauchyCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CauchyCpuKernel() = default;
|
||||
~CauchyCpuKernel() = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t CauchyCompute(CpuKernelContext &ctx);
|
||||
|
||||
float median_ = 0.0;
|
||||
float sigma_ = 1.0;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cholesky_grad.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include "cpu_kernel_utils.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *CholeskyGrad = "CholeskyGrad";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CholeskyGradCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskyGrad check input and output number failed.");
|
||||
Tensor *input0 = ctx.Input(0);
|
||||
Tensor *input1 = ctx.Input(1);
|
||||
Tensor *output0 = ctx.Output(0);
|
||||
if (input0->GetDataSize() == 0 || input1->GetDataSize() == 0) {
|
||||
KERNEL_LOG_ERROR("[%s] Input tensor is empty tensor.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto shape0 = input0->GetTensorShape();
|
||||
auto shape1 = input1->GetTensorShape();
|
||||
auto shape2 = output0->GetTensorShape();
|
||||
if (shape0->GetDims() != shape1->GetDims() || shape1->GetDims() != shape2->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Inputs and Output tensors should have same dims.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto dims = shape0->GetDims();
|
||||
if (shape0->GetDimSize(dims - 1) != shape0->GetDimSize(dims - 2)) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input0 is not square.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if ((shape0->GetDimSize(dims - 1) != shape1->GetDimSize(dims - 1)) ||
|
||||
(shape0->GetDimSize(dims - 2) != shape1->GetDimSize(dims - 2))) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input0&input1's shape mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if ((shape0->GetDimSize(dims - 1) != shape2->GetDimSize(dims - 1)) ||
|
||||
(shape0->GetDimSize(dims - 2) != shape2->GetDimSize(dims - 2))) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input0&output0's shape mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto data_type_0 = input0->GetDataType();
|
||||
auto data_type_1 = input1->GetDataType();
|
||||
auto data_type_2 = output0->GetDataType();
|
||||
if (data_type_0 != data_type_1 || data_type_0 != data_type_2) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor data type mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (data_type_0 != DT_FLOAT && data_type_0 != DT_DOUBLE) {
|
||||
KERNEL_LOG_ERROR("CholeskyGrad kernel data type [%u] not support.", data_type_0);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (data_type_0 == DT_FLOAT) {
|
||||
return ComputeKernel<float>(ctx, true);
|
||||
} else {
|
||||
return ComputeKernel<double>(ctx, true);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t CholeskyGradCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &reverse) {
|
||||
auto dims = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
auto lptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto gradptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto outputptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
int n = ctx.Input(0)->GetTensorShape()->GetDimSize(dims - 1);
|
||||
int64_t data_num = ctx.Input(0)->NumElements();
|
||||
const int64_t mat_size = n * n;
|
||||
const int64_t batch = data_num / mat_size;
|
||||
const int64_t kParallelDataNum = 16 * mat_size;
|
||||
const int64_t kParallelDataNumMid = 72 * mat_size;
|
||||
if (data_num >= kParallelDataNum) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
|
||||
if (data_num <= kParallelDataNumMid) {
|
||||
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
|
||||
}
|
||||
|
||||
auto sharder_cholesky_grad = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
|
||||
}
|
||||
};
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch, batch / max_core_num, sharder_cholesky_grad),
|
||||
"CholeskyGrad Compute failed.");
|
||||
|
||||
} else {
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CholeskyGradCpuKernel::ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigengrad(n, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigenl(n, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> output_matrix(n, n);
|
||||
for (int i = 0; i < n * n; i++) {
|
||||
*(eigengrad.data() + i) = *(gradptr + i);
|
||||
*(eigenl.data() + i) = *(lptr + i);
|
||||
}
|
||||
|
||||
// Algorithm only depends on lower triangular half on input_matrix_l.
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> input_matrix_l =
|
||||
eigenl.template triangularView<Eigen::Lower>();
|
||||
// Algorithm only depends on lower triangular half on input_matrix_grad.
|
||||
output_matrix = eigengrad.template triangularView<Eigen::Lower>();
|
||||
|
||||
const int64_t kMatrixSize = input_matrix_l.rows();
|
||||
const int64_t kMaxBlockSize = 32;
|
||||
|
||||
for (int64_t block_end = kMatrixSize; block_end > 0; block_end -= kMaxBlockSize) {
|
||||
const int64_t block_begin = std::max(int64_t{0}, block_end - kMaxBlockSize);
|
||||
const int64_t block_size = block_end - block_begin;
|
||||
const int64_t trailing_size = kMatrixSize - block_end;
|
||||
|
||||
auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
|
||||
auto B_bar = output_matrix.block(block_end, 0, trailing_size, block_begin);
|
||||
|
||||
auto C = input_matrix_l.block(block_end, block_begin, trailing_size, block_size);
|
||||
auto C_bar = output_matrix.block(block_end, block_begin, trailing_size, block_size);
|
||||
|
||||
auto D = input_matrix_l.block(block_begin, block_begin, block_size, block_size);
|
||||
auto D_bar = output_matrix.block(block_begin, block_begin, block_size, block_size);
|
||||
|
||||
auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
|
||||
auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
|
||||
|
||||
C_bar = D.adjoint().template triangularView<Eigen::Upper>().solve(C_bar.adjoint()).adjoint();
|
||||
D_bar -= (C_bar.adjoint() * C).template triangularView<Eigen::Lower>();
|
||||
B_bar -= C_bar * R;
|
||||
R_bar -= C_bar.adjoint() * B;
|
||||
CholeskyGradUnblocked<T>(D, D_bar);
|
||||
R_bar -= (D_bar + D_bar.adjoint()) * R;
|
||||
}
|
||||
output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < n; j++) {
|
||||
*(outputptr + i * n + j) = output_matrix(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CholeskyGradCpuKernel::CholeskyGradUnblocked(
|
||||
const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
|
||||
Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block) {
|
||||
const int64_t kMatrixSize = l_block.rows();
|
||||
for (int64_t k = kMatrixSize - 1; k >= 0; k--) {
|
||||
const int64_t number_rows_B = kMatrixSize - (k + 1);
|
||||
const int64_t number_rows_r_stack_B = number_rows_B + 1;
|
||||
|
||||
auto r = l_block.block(k, 0, 1, k);
|
||||
auto r_bar = grad_block.block(k, 0, 1, k);
|
||||
auto d = l_block(k, k); // This needs to be a scalar rather than a view.
|
||||
auto d_bar = grad_block.block(k, k, 1, 1);
|
||||
// B is not included explicitly because it is not used on its own.
|
||||
auto B_bar = grad_block.block(k + 1, 0, number_rows_B, k);
|
||||
auto c = l_block.block(k + 1, k, number_rows_B, 1);
|
||||
auto c_bar = grad_block.block(k + 1, k, number_rows_B, 1);
|
||||
// Result of vertical stacking d_bar and c_bar.
|
||||
auto d_stack_c_bar = grad_block.block(k, k, number_rows_r_stack_B, 1);
|
||||
// Result of vertical stacking of r and B.
|
||||
auto r_stack_B = l_block.block(k, 0, number_rows_r_stack_B, k);
|
||||
d_bar -= (c.adjoint() * c_bar) / d;
|
||||
d_stack_c_bar /= d;
|
||||
r_bar -= d_stack_c_bar.adjoint() * r_stack_B;
|
||||
B_bar -= c_bar * r;
|
||||
d_bar /= 2.;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(CholeskyGrad, CholeskyGradCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
#include <Eigen/Dense>
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include <unsupported/Eigen/MatrixFunctions>
|
||||
|
||||
namespace aicpu {
|
||||
class CholeskyGradCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CholeskyGradCpuKernel() = default;
|
||||
~CholeskyGradCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &reverse);
|
||||
|
||||
template <typename T>
|
||||
void ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n);
|
||||
|
||||
template <typename T>
|
||||
void CholeskyGradUnblocked(
|
||||
const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
|
||||
Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cholesky_inverse.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include <Eigen/Dense>
|
||||
#include <iostream>
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 1;
|
||||
const uint32_t dimension = 2;
|
||||
const char *kCholeskyInverse = "CholeskyInverse";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CholeskyInverseCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check CholeskyInverse params failed.");
|
||||
Tensor *input = ctx.Input(0);
|
||||
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");
|
||||
Tensor *output = ctx.Output(0);
|
||||
auto inputShape = input->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(inputShape, KERNEL_STATUS_PARAM_INVALID, "Get inputShape failed.");
|
||||
AttrValue *upper = ctx.GetAttr("upper");
|
||||
KERNEL_CHECK_NULLPTR(upper, KERNEL_STATUS_PARAM_INVALID, "Get upper failed.");
|
||||
KERNEL_LOG_DEBUG(
|
||||
"CholeskyInverseCpuKernel[%s], input: size[%llu];"
|
||||
"output: size[%llu].",
|
||||
ctx.GetOpType().c_str(), input->GetDataSize(), output->GetDataSize());
|
||||
auto input_dims = inputShape->GetDims();
|
||||
if (input_dims != dimension) {
|
||||
KERNEL_LOG_ERROR("CholeskyInverse input dim must be 2!");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else if (inputShape->GetDimSize(input_dims - 2) != inputShape->GetDimSize(input_dims - 1)) {
|
||||
KERNEL_LOG_ERROR("CholeskyInverse input matrix must be square matrix!");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
DataType data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
case DT_FLOAT:
|
||||
return CholeskyInverseCompute<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return CholeskyInverseCompute<double>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("CholeskyInverse kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
uint32_t CholeskyInverseCpuKernel::CholeskyInverseCompute(CpuKernelContext &ctx) {
|
||||
auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
auto inputShape = ctx.Input(0)->GetTensorShape();
|
||||
int64_t n = inputShape->GetDimSize(0);
|
||||
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
|
||||
Eigen::Map<MatrixXd> A(input_x, n, n);
|
||||
MatrixXd result;
|
||||
AttrValue *upper = ctx.GetAttr("upper");
|
||||
bool val = upper->GetBool();
|
||||
if (val) {
|
||||
result = (A.transpose() * A).inverse();
|
||||
} else {
|
||||
result = (A * A.transpose()).inverse();
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
*(output_y + i * n + j) = result(i, j);
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kCholeskyInverse, CholeskyInverseCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,19 @@
|
|||
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class CholeskyInverseCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CholeskyInverseCpuKernel() = default;
|
||||
~CholeskyInverseCpuKernel() = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static uint32_t CholeskyInverseCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cholesky_solve.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <Eigen/Dense>
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *CholeskySolve = "CholeskySolve";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CholeskySolveCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskySolve check input and output failed.");
|
||||
|
||||
Tensor *input_x1 = ctx.Input(0);
|
||||
AttrValue *upper = ctx.GetAttr("upper");
|
||||
bool upperinfo = (upper == nullptr) ? false : upper->GetBool();
|
||||
auto data_type_x1 = input_x1->GetDataType();
|
||||
|
||||
switch (data_type_x1) {
|
||||
case DT_FLOAT:
|
||||
return ComputeKernel<float>(ctx, upperinfo);
|
||||
case DT_DOUBLE:
|
||||
return ComputeKernel<double>(ctx, upperinfo);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("CholeskySolve kernel data type [%s] not support.", DTypeStr(data_type_x1).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(CholeskySolve, CholeskySolveCpuKernel);
|
||||
|
||||
template <typename T>
|
||||
uint32_t CholeskySolveCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &upper) {
|
||||
auto rhsptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto lhsptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto outptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
size_t batch_size = 1;
|
||||
std::vector<int64_t> dims = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
size_t dimsnum = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
size_t dim = dims[dimsnum - 2];
|
||||
size_t rhs_dim = dims[dimsnum - 1];
|
||||
if (dimsnum == 3) {
|
||||
batch_size = dims[dimsnum - 3];
|
||||
}
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RHS(dim, rhs_dim);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> LHS(dim, dim);
|
||||
for (size_t k = 0; k < batch_size; k++) {
|
||||
for (size_t i = 0; i < dim * rhs_dim; i++) {
|
||||
RHS.data()[i] = rhsptr[k * dim * rhs_dim + i];
|
||||
}
|
||||
for (size_t i = 0; i < dim * dim; i++) {
|
||||
LHS.data()[i] = lhsptr[k * dim * dim + i];
|
||||
}
|
||||
if (!upper) {
|
||||
LHS.template triangularView<Eigen::Lower>().solveInPlace(RHS);
|
||||
LHS.adjoint().template triangularView<Eigen::Upper>().solveInPlace(RHS);
|
||||
} else {
|
||||
LHS.adjoint().template triangularView<Eigen::Lower>().solveInPlace(RHS);
|
||||
LHS.template triangularView<Eigen::Upper>().solveInPlace(RHS);
|
||||
}
|
||||
for (size_t i = 0; i < dim * rhs_dim; i++) {
|
||||
outptr[k * dim * rhs_dim + i] = RHS.data()[i];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CholeskySolveCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CholeskySolveCpuKernel() = default;
|
||||
~CholeskySolveCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &upper);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,447 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "combined_non_max_suppression.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 6;
|
||||
const uint32_t kOutputNum = 4;
|
||||
const char *kCombinedNonMaxSuppression = "CombinedNonMaxSuppression";
|
||||
|
||||
void alloc_zeros(float *arr, int arr_len) {
|
||||
for (int i = 0; i < arr_len; i++) {
|
||||
arr[i] = (float)0.0;
|
||||
}
|
||||
}
|
||||
|
||||
void alloc_zeros(int *arr, int arr_len) {
|
||||
for (int i = 0; i < arr_len; i++) {
|
||||
arr[i] = 0;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
// Normalize the diagonal to the input mode of bottom left and top right
|
||||
void CombinedNonMaxSuppressionCpuKernel::regular_input2buffer(float **boxes_buffer, float *box_src,
|
||||
const int class_idx) {
|
||||
/**
|
||||
* shape of box_src
|
||||
* box_src[num_boxes*q*4]
|
||||
* ways to visit box_src[i][class_idx][k] which stored by 1-dimension
|
||||
* box_src[i][class_idx][k]=box_src[i*q*4+class_idx*4+k]
|
||||
*/
|
||||
int box_len1;
|
||||
int sub_box_len1 = q * 4;
|
||||
int box_len2 = (class_idx << 2);
|
||||
for (int i = 0; i < num_boxes; i++) {
|
||||
box_len1 = i * sub_box_len1 + box_len2;
|
||||
if (box_src[box_len1] > box_src[box_len1 + 2]) {
|
||||
boxes_buffer[i][0] = box_src[box_len1 + 2];
|
||||
boxes_buffer[i][2] = box_src[box_len1 + 0];
|
||||
} else {
|
||||
boxes_buffer[i][0] = box_src[box_len1 + 0];
|
||||
boxes_buffer[i][2] = box_src[box_len1 + 2];
|
||||
}
|
||||
if (box_src[box_len1 + 1] > box_src[box_len1 + 3]) {
|
||||
boxes_buffer[i][1] = box_src[box_len1 + 3];
|
||||
boxes_buffer[i][3] = box_src[box_len1 + 1];
|
||||
} else {
|
||||
boxes_buffer[i][1] = box_src[box_len1 + 1];
|
||||
boxes_buffer[i][3] = box_src[box_len1 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate the area ratio of the intersection of two squares
|
||||
float CombinedNonMaxSuppressionCpuKernel::IOU(float **boxes_buffer, int i, int j) {
|
||||
const float *box_a = boxes_buffer[i];
|
||||
const float *box_b = boxes_buffer[j];
|
||||
float lx, ly, rx, ry;
|
||||
float w, h;
|
||||
float area;
|
||||
float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
|
||||
float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
|
||||
if (area_a <= 0 || area_b <= 0) {
|
||||
return 0.0;
|
||||
}
|
||||
lx = box_a[0] > box_b[0] ? box_a[0] : box_b[0];
|
||||
ly = box_a[1] > box_b[1] ? box_a[1] : box_b[1];
|
||||
rx = box_a[2] < box_b[2] ? box_a[2] : box_b[2];
|
||||
ry = box_a[3] < box_b[3] ? box_a[3] : box_b[3];
|
||||
w = rx > lx ? (rx - lx) : 0;
|
||||
h = ry > ly ? (ry - ly) : 0;
|
||||
area = w * h;
|
||||
return area / (area_a + area_b - area);
|
||||
}
|
||||
|
||||
/**
|
||||
* if soft_nms_sigma > 0.0, soft_nms is used, means update by score=score*exp(scale*iou^2)
|
||||
* if soft_nms_sigma <= 0.0, nms is used, means delete it when iou > iou_threshold
|
||||
* run non max suppression per bath per class
|
||||
*/
|
||||
void CombinedNonMaxSuppressionCpuKernel::non_max_suppression(float **boxes_buffer, float *scores_buffer,
|
||||
std::vector<int> &selected) {
|
||||
std::priority_queue<non_max_suppression_local::score_index> pq;
|
||||
for (int i = 0; i < num_boxes; i++) {
|
||||
if (scores_buffer[i] > score_threshold) {
|
||||
pq.push(non_max_suppression_local::score_index(i, scores_buffer[i], 0));
|
||||
}
|
||||
}
|
||||
|
||||
float scale = static_cast<float>(0.0);
|
||||
bool is_soft_nms = soft_nms_sigma > static_cast<float>(0.0);
|
||||
if (is_soft_nms) {
|
||||
scale = static_cast<float>(-0.5) / soft_nms_sigma;
|
||||
}
|
||||
|
||||
float similarity;
|
||||
float original_score;
|
||||
non_max_suppression_local::score_index next_si;
|
||||
|
||||
while (((int)selected.size() < size_per_class) && (!pq.empty())) {
|
||||
next_si = pq.top();
|
||||
original_score = next_si.score;
|
||||
pq.pop();
|
||||
bool should_hard_suppress = false;
|
||||
|
||||
for (int j = static_cast<int>(selected.size()) - 1; j >= next_si.suppress_begin_index; j--) {
|
||||
similarity = IOU(boxes_buffer, next_si.box_index, selected[j]);
|
||||
if (is_soft_nms) {
|
||||
next_si.score *=
|
||||
similarity <= iou_threshold ? std::exp(scale * similarity * similarity) : static_cast<float>(0.0);
|
||||
}
|
||||
if (!is_soft_nms && similarity > iou_threshold) {
|
||||
should_hard_suppress = true;
|
||||
break;
|
||||
}
|
||||
if (next_si.score <= score_threshold) break;
|
||||
}
|
||||
|
||||
next_si.suppress_begin_index = selected.size();
|
||||
if (!should_hard_suppress) {
|
||||
if (next_si.score == original_score) {
|
||||
selected.push_back(next_si.box_index);
|
||||
continue;
|
||||
}
|
||||
if (next_si.score > score_threshold) {
|
||||
pq.push(next_si);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CombinedNonMaxSuppressionCpuKernel::nms_perclass(
|
||||
float *boxes, float *scores, std::vector<non_max_suppression_local::result_para> &sub_result_vec, int &result_size) {
|
||||
int k = 0;
|
||||
int box_idx;
|
||||
int boxe_len1;
|
||||
int sub_box_len1 = q * 4;
|
||||
int box_len2 = 0;
|
||||
float **boxes_buffer = new float *[num_boxes]();
|
||||
float *scores_buffer = new float[num_boxes]();
|
||||
for (int i = 0; i < num_boxes; i++) {
|
||||
boxes_buffer[i] = new float[4];
|
||||
}
|
||||
/**
|
||||
* shape of score and boxes
|
||||
* score[num_boxes*num_class]
|
||||
* boxes[num_boxes*q*4]
|
||||
*/
|
||||
if (q == 1) {
|
||||
regular_input2buffer(boxes_buffer, boxes, 0);
|
||||
}
|
||||
for (int j = 0; j < num_class; j++) {
|
||||
for (int i = 0; i < num_boxes; i++) {
|
||||
scores_buffer[i] = scores[i * num_class + j];
|
||||
}
|
||||
if (q > 1) {
|
||||
regular_input2buffer(boxes_buffer, boxes, j);
|
||||
box_len2 = j * 4;
|
||||
}
|
||||
std::vector<int> selected;
|
||||
non_max_suppression(boxes_buffer, scores_buffer, selected);
|
||||
for (int i = 0; i < (int)selected.size(); i++) {
|
||||
box_idx = selected[i];
|
||||
boxe_len1 = box_idx * sub_box_len1 + box_len2;
|
||||
sub_result_vec[k++] = {box_idx,
|
||||
scores_buffer[box_idx],
|
||||
j,
|
||||
{boxes[boxe_len1 + 0], boxes[boxe_len1 + 1], boxes[boxe_len1 + 2], boxes[boxe_len1 + 3]}};
|
||||
}
|
||||
result_size += selected.size();
|
||||
}
|
||||
for (int i = 0; i < num_boxes; i++) {
|
||||
delete[] boxes_buffer[i];
|
||||
}
|
||||
delete[] boxes_buffer;
|
||||
delete[] scores_buffer;
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t CombinedNonMaxSuppressionCpuKernel::nms_perbath(CpuKernelContext &ctx, float *boxes, float *scores,
|
||||
float *nmsed_boxes, float *nmsed_scores, float *nmsed_class,
|
||||
int *valid_detection) {
|
||||
alloc_zeros(nmsed_boxes, num_bath * num_detection * 4);
|
||||
alloc_zeros(nmsed_scores, num_bath * num_detection);
|
||||
alloc_zeros(nmsed_class, num_bath * num_detection);
|
||||
alloc_zeros(valid_detection, num_bath);
|
||||
const float box_min = 0.0;
|
||||
const float box_max = 1.0;
|
||||
/**
|
||||
* shape of scores and boxes:
|
||||
* scores[num_bath*num_boxes*num_class]
|
||||
* boxes[num_bath*num_boxes*q*4]
|
||||
*/
|
||||
int score_len2 = num_boxes * num_class;
|
||||
int boxes_len2 = num_boxes * q * 4;
|
||||
auto shard_nms = [&](size_t start, size_t end) {
|
||||
for (int i = start; i < (int)end; i++) {
|
||||
int per_detections = 0;
|
||||
int scores_index = 0;
|
||||
int result_size = 0;
|
||||
std::vector<non_max_suppression_local::result_para> result_vec(size_per_class * num_class,
|
||||
{0, 0.0, 0, {0.0, 0.0, 0.0, 0.0}});
|
||||
nms_perclass(boxes + i * boxes_len2, scores + i * score_len2, result_vec, result_size);
|
||||
if (!pad_per_class) {
|
||||
per_detections = std::min(result_size, max_total_size);
|
||||
} else {
|
||||
per_detections = std::min(result_size, num_detection);
|
||||
}
|
||||
std::sort(result_vec.begin(), result_vec.begin() + result_size, non_max_suppression_local::result_cmp);
|
||||
scores_index = i * num_detection;
|
||||
for (int k = 0; k < per_detections; k++) {
|
||||
if (clip_boxes) {
|
||||
nmsed_boxes[(scores_index << 2) + 0] = std::max(std::min(result_vec[k].box_coord[0], box_max), box_min);
|
||||
nmsed_boxes[(scores_index << 2) + 1] = std::max(std::min(result_vec[k].box_coord[1], box_max), box_min);
|
||||
nmsed_boxes[(scores_index << 2) + 2] = std::max(std::min(result_vec[k].box_coord[2], box_max), box_min);
|
||||
nmsed_boxes[(scores_index << 2) + 3] = std::max(std::min(result_vec[k].box_coord[3], box_max), box_min);
|
||||
nmsed_scores[scores_index] = result_vec[k].score;
|
||||
nmsed_class[scores_index] = (float)result_vec[k].class_idx;
|
||||
} else {
|
||||
nmsed_boxes[(scores_index << 2) + 0] = result_vec[k].box_coord[0];
|
||||
nmsed_boxes[(scores_index << 2) + 1] = result_vec[k].box_coord[1];
|
||||
nmsed_boxes[(scores_index << 2) + 2] = result_vec[k].box_coord[2];
|
||||
nmsed_boxes[(scores_index << 2) + 3] = result_vec[k].box_coord[3];
|
||||
nmsed_scores[scores_index] = result_vec[k].score;
|
||||
nmsed_class[scores_index] = (float)result_vec[k].class_idx;
|
||||
}
|
||||
scores_index++;
|
||||
}
|
||||
valid_detection[i] = per_detections;
|
||||
}
|
||||
};
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (max_core_num > num_bath) {
|
||||
max_core_num = num_bath;
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_bath, num_bath / max_core_num, shard_nms),
|
||||
"CombinedNonMaxSuppression Compute failed in nms_perbath stage.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CombinedNonMaxSuppressionCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
|
||||
"CombinedNonMaxSuppression check input and output number failed.");
|
||||
KERNEL_HANDLE_ERROR(CombinedNonMaxSuppressionCheck(ctx), "CombinedNonMaxSuppression check params failed.");
|
||||
CombinedNonMaxSuppressionCompute(ctx);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx) {
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 2 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 3 data failed.");
|
||||
if (ctx.Input(4) != nullptr) {
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(4)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 4 data failed.");
|
||||
}
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(5)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 5 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 0 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 1 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 2 data failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 3 data failed.");
|
||||
KERNEL_CHECK_FALSE((ctx.Input(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input0 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(0)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input1 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(1)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Input(2)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input2 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(2)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Input(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input3 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(3)->GetDataType()).c_str());
|
||||
if (ctx.Input(4) != NULL) {
|
||||
KERNEL_CHECK_FALSE((ctx.Input(4)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input4 [%s] must be [DT_FLOAT].",
|
||||
DTypeStr(ctx.Input(4)->GetDataType()).c_str());
|
||||
}
|
||||
KERNEL_CHECK_FALSE((ctx.Input(5)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input5 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(5)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Output(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of output0 [%s] must be [DT_FLOAT].",
|
||||
DTypeStr(ctx.Output(0)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Output(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of output1 [%s] must be [DT_FLOAT].",
|
||||
DTypeStr(ctx.Output(1)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Output(2)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of output2 [%s] must be [DT_FLOAT].",
|
||||
DTypeStr(ctx.Output(2)->GetDataType()).c_str());
|
||||
KERNEL_CHECK_FALSE((ctx.Output(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of output3 [%s] must be [DT_INT32].",
|
||||
DTypeStr(ctx.Output(3)->GetDataType()).c_str());
|
||||
auto input0_shape = ctx.Input(0)->GetTensorShape();
|
||||
auto input1_shape = ctx.Input(1)->GetTensorShape();
|
||||
auto input2_shape = ctx.Input(2)->GetTensorShape();
|
||||
auto input3_shape = ctx.Input(3)->GetTensorShape();
|
||||
auto input5_shape = ctx.Input(5)->GetTensorShape();
|
||||
KERNEL_CHECK_FALSE((input0_shape->GetDims() == 4), KERNEL_STATUS_PARAM_INVALID, "The input0's dims [%d] must be 4",
|
||||
input0_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE((input1_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The input1's dims [%d] must be 3",
|
||||
input1_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE(
|
||||
(input2_shape->GetDims() == 0 || (input2_shape->GetDims() == 1 && input2_shape->GetDimSize(0) == 1)),
|
||||
KERNEL_STATUS_PARAM_INVALID, "The input2's dims [%d] must be 0 or 1x1", input2_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE(
|
||||
(input3_shape->GetDims() == 0 || (input3_shape->GetDims() == 1 && input3_shape->GetDimSize(0) == 1)),
|
||||
KERNEL_STATUS_PARAM_INVALID, "The input3's dims [%d] must be 0 or 1x1", input3_shape->GetDims());
|
||||
if (ctx.Input(4) != nullptr) {
|
||||
auto input4_shape = ctx.Input(4)->GetTensorShape();
|
||||
KERNEL_CHECK_FALSE(
|
||||
(input4_shape->GetDims() == 0 || (input4_shape->GetDims() == 1 && input4_shape->GetDimSize(0) == 1)),
|
||||
KERNEL_STATUS_PARAM_INVALID, "The input4's dims [%d] must be 0 or 1x1", input4_shape->GetDims());
|
||||
}
|
||||
KERNEL_CHECK_FALSE(
|
||||
(input5_shape->GetDims() == 0 || (input5_shape->GetDims() == 1 && input5_shape->GetDimSize(0) == 1)),
|
||||
KERNEL_STATUS_PARAM_INVALID, "The input5's dims [%d] must be 0 or 1x1", input5_shape->GetDims());
|
||||
auto output0_shape = ctx.Output(0)->GetTensorShape();
|
||||
auto output1_shape = ctx.Output(1)->GetTensorShape();
|
||||
auto output2_shape = ctx.Output(2)->GetTensorShape();
|
||||
auto output3_shape = ctx.Output(3)->GetTensorShape();
|
||||
KERNEL_CHECK_FALSE((output0_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The output0's dims [%d] must be 3",
|
||||
output0_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE((output1_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output1's dims [%d] must be 2",
|
||||
output1_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE((output2_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output2's dims [%d] must be 2",
|
||||
output2_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE((output3_shape->GetDims() == 1), KERNEL_STATUS_PARAM_INVALID, "The output3's dims [%d] must be 1",
|
||||
output3_shape->GetDims());
|
||||
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(0) == input1_shape->GetDimSize(0)), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 1st dims [%d] need be same with the input1's 1st dims[%d]",
|
||||
input0_shape->GetDimSize(0), input1_shape->GetDimSize(0));
|
||||
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(1) == input1_shape->GetDimSize(1)), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 2nd dims [%d] need be same with the input1's 2nd dims[%d]",
|
||||
input0_shape->GetDimSize(1), input1_shape->GetDimSize(1));
|
||||
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(2) == input1_shape->GetDimSize(2) || input0_shape->GetDimSize(2) == 1),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 3th dims [%d] need be same with the input1's 3th dims [%d] or 1",
|
||||
input0_shape->GetDimSize(2), output1_shape->GetDimSize(2));
|
||||
KERNEL_CHECK_FALSE((input0_shape->GetDimSize(3) == 4), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 4th dims [%d] need be same with 4", input0_shape->GetDimSize(1));
|
||||
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == output1_shape->GetDimSize(0) &&
|
||||
output0_shape->GetDimSize(0) == output2_shape->GetDimSize(0) &&
|
||||
output0_shape->GetDimSize(0) == output3_shape->GetDimSize(0)),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 1st dims [%d], input1's 1st dims [%d],"
|
||||
" input2's 1st dims [%d], input3's 1st dims [%d], need be same with each other",
|
||||
output0_shape->GetDimSize(0), output1_shape->GetDimSize(0), output2_shape->GetDimSize(0),
|
||||
output3_shape->GetDimSize(0));
|
||||
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(1) == output1_shape->GetDimSize(1) &&
|
||||
output0_shape->GetDimSize(1) == output2_shape->GetDimSize(1)),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"The input0's 2nd dims [%d], input1's 2nd dims [%d], input2's 2nd dims [%d],"
|
||||
" need be same with each other",
|
||||
output0_shape->GetDimSize(1), output1_shape->GetDimSize(1), output2_shape->GetDimSize(1));
|
||||
KERNEL_LOG_INFO(
|
||||
" CombinedNonMaxSuppressionCpuKernel[%s], input0: size[%llu], "
|
||||
" input1: size[%llu]",
|
||||
ctx.GetOpType().c_str(), ctx.Input(0)->GetDataSize(), ctx.Input(1)->GetDataSize());
|
||||
KERNEL_LOG_INFO(
|
||||
" output0: size[%llu], output1: size[%llu],"
|
||||
" output2: size[%llu], output3: size[%llu].",
|
||||
ctx.Output(0)->GetDataSize(), ctx.Output(1)->GetDataSize(), ctx.Output(2)->GetDataSize(),
|
||||
ctx.Output(3)->GetDataSize());
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx) {
|
||||
float *boxes = reinterpret_cast<float *>(ctx.Input(0)->GetData());
|
||||
float *scores = reinterpret_cast<float *>(ctx.Input(1)->GetData());
|
||||
max_output_size_per_class = *(reinterpret_cast<int *>(ctx.Input(2)->GetData()));
|
||||
max_total_size = *(reinterpret_cast<int *>(ctx.Input(3)->GetData()));
|
||||
iou_threshold = *(reinterpret_cast<float *>(ctx.Input(4)->GetData()));
|
||||
score_threshold = *(reinterpret_cast<float *>(ctx.Input(5)->GetData()));
|
||||
num_bath = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(0));
|
||||
num_boxes = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(1));
|
||||
q = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(2));
|
||||
num_class = (int)(ctx.Input(1)->GetTensorShape()->GetDimSize(2));
|
||||
pad_per_class = false;
|
||||
clip_boxes = true;
|
||||
if (ctx.GetAttr("pad_per_class") != nullptr) {
|
||||
pad_per_class = (bool)(ctx.GetAttr("pad_per_class")->GetBool());
|
||||
}
|
||||
if (ctx.GetAttr("clip_boxes") != nullptr) {
|
||||
clip_boxes = (bool)(ctx.GetAttr("clip_boxes")->GetBool());
|
||||
}
|
||||
float *nmsed_boxes = reinterpret_cast<float *>(ctx.Output(0)->GetData());
|
||||
float *nmsed_scores = reinterpret_cast<float *>(ctx.Output(1)->GetData());
|
||||
float *nmsed_class = reinterpret_cast<float *>(ctx.Output(2)->GetData());
|
||||
int *valid_detection = reinterpret_cast<int *>(ctx.Output(3)->GetData());
|
||||
auto output0_shape = ctx.Output(0)->GetTensorShape();
|
||||
auto output1_shape = ctx.Output(1)->GetTensorShape();
|
||||
auto output2_shape = ctx.Output(2)->GetTensorShape();
|
||||
auto output3_shape = ctx.Output(3)->GetTensorShape();
|
||||
size_per_class = max_output_size_per_class < num_boxes ? max_output_size_per_class : num_boxes;
|
||||
soft_nms_sigma = 0.0;
|
||||
if (pad_per_class) {
|
||||
num_detection = std::min(max_total_size, max_output_size_per_class * num_class);
|
||||
} else {
|
||||
num_detection = max_total_size;
|
||||
}
|
||||
KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output0's 1nd dims [%d] must be [%d]", output0_shape->GetDimSize(0), num_bath);
|
||||
KERNEL_CHECK_FALSE((output1_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output0's 1nd dims [%d] must be [%d]", output1_shape->GetDimSize(0), num_bath);
|
||||
KERNEL_CHECK_FALSE((output2_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output0's 1nd dims [%d] must be [%d]", output2_shape->GetDimSize(0), num_bath);
|
||||
KERNEL_CHECK_FALSE((output3_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output0's 1nd dims [%d] must be [%d]", output3_shape->GetDimSize(0), num_bath);
|
||||
KERNEL_CHECK_FALSE((max_output_size_per_class > 0), KERNEL_STATUS_PARAM_INVALID,
|
||||
"max_output_size_per_class [%d] must be > 0", max_output_size_per_class);
|
||||
KERNEL_CHECK_FALSE((max_total_size > 0), KERNEL_STATUS_PARAM_INVALID, "max_total_size [%d] must be > 0",
|
||||
max_total_size);
|
||||
KERNEL_CHECK_FALSE((iou_threshold >= 0 && iou_threshold <= 1), KERNEL_STATUS_PARAM_INVALID,
|
||||
"iou_threshold [%f] must be in [0,1]", iou_threshold);
|
||||
KERNEL_CHECK_FALSE(((int)output0_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output0's 2nd dims [%d] need be same with %d", output0_shape->GetDimSize(1), num_detection);
|
||||
KERNEL_CHECK_FALSE(((int)output1_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output1's 2nd dims [%d] need be same with %d", output1_shape->GetDimSize(1), num_detection);
|
||||
KERNEL_CHECK_FALSE(((int)output2_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output2's 2nd dims [%d] need be same with %d", output2_shape->GetDimSize(1), num_detection);
|
||||
nms_perbath(ctx, boxes, scores, nmsed_boxes, nmsed_scores, nmsed_class, valid_detection);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kCombinedNonMaxSuppression, CombinedNonMaxSuppressionCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace non_max_suppression_local {
|
||||
struct score_index {
|
||||
int box_index;
|
||||
float score;
|
||||
int suppress_begin_index;
|
||||
score_index() {}
|
||||
score_index(int bi, float s, int sbi) : box_index(bi), score(s), suppress_begin_index(sbi) {}
|
||||
bool operator<(const score_index &b) const {
|
||||
return (score < b.score) || ((score == b.score) && (box_index > b.box_index));
|
||||
}
|
||||
};
|
||||
|
||||
struct result_para {
|
||||
int box_index;
|
||||
float score;
|
||||
int class_idx;
|
||||
float box_coord[4];
|
||||
};
|
||||
|
||||
bool result_cmp(result_para &a, result_para &b) { return a.score > b.score; }
|
||||
} // namespace non_max_suppression_local
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class CombinedNonMaxSuppressionCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CombinedNonMaxSuppressionCpuKernel() = default;
|
||||
~CombinedNonMaxSuppressionCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx);
|
||||
|
||||
uint32_t CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx);
|
||||
uint32_t nms_perbath(CpuKernelContext &, float *, float *, float *, float *, float *, int *);
|
||||
void regular_input2buffer(float **, float *, const int);
|
||||
float IOU(float **, int, int);
|
||||
void non_max_suppression(float **, float *, std::vector<int> &);
|
||||
void nms_perclass(float *, float *, std::vector<non_max_suppression_local::result_para> &, int &);
|
||||
int num_bath;
|
||||
int num_boxes;
|
||||
int q;
|
||||
int num_class;
|
||||
// per batch size
|
||||
int num_detection;
|
||||
int max_total_size;
|
||||
// The length of each type of selection defined by the user
|
||||
int max_output_size_per_class;
|
||||
// Calculation num_detection length
|
||||
int size_per_class;
|
||||
// When lower than a score_threshold, delete the relevant box
|
||||
float score_threshold;
|
||||
// When it is higher than the threshold value, according to the soft_nms_sigma determines deletion or decay
|
||||
float iou_threshold;
|
||||
float soft_nms_sigma;
|
||||
bool pad_per_class;
|
||||
bool clip_boxes;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "complex.h"
|
||||
#include "Eigen/Eigen"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 2;
|
||||
const char *kComplex = "Complex";
|
||||
constexpr int64_t kFloatMaxNums = 8 * 128 * 1024;
|
||||
constexpr int64_t kDoubleMaxNums = 16 * 128 * 1024;
|
||||
#define Complex_COMPUTE_CASE(IN_DTYPE, IN_TYPE, OUT_DTYPE, CTX) \
|
||||
case (IN_DTYPE): { \
|
||||
switch (OUT_DTYPE) { \
|
||||
case (DT_COMPLEX64): { \
|
||||
uint32_t result = ComplexCompute<float, std::complex<float>>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Complex kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case (DT_COMPLEX128): { \
|
||||
uint32_t result = ComplexCompute<double, std::complex<double>>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Complex kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
KERNEL_LOG_ERROR("Complex kernel output data type [%s] not support.", DTypeStr(OUT_DTYPE).c_str()); \
|
||||
return KERNEL_STATUS_PARAM_INVALID; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ComplexCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kComplex);
|
||||
DataType input_type = ctx.Input(0)->GetDataType();
|
||||
switch (input_type) {
|
||||
Complex_COMPUTE_CASE(DT_FLOAT, float, DT_COMPLEX64, ctx)
|
||||
Complex_COMPUTE_CASE(DT_DOUBLE, double, DT_COMPLEX128, ctx) default
|
||||
: KERNEL_LOG_ERROR("Complex kernel input data type [%s] not support.", DTypeStr(input_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T, typename t>
|
||||
uint32_t ComplexCpuKernel::ComplexCompute(CpuKernelContext &ctx) {
|
||||
auto input0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto input1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto output = reinterpret_cast<t *>(ctx.Output(0)->GetData());
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
int64_t data_num = ctx.Output(0)->NumElements();
|
||||
int64_t data_size = data_num * sizeof(T);
|
||||
if ((data_type == DT_FLOAT && data_size <= kFloatMaxNums) ||
|
||||
(data_type == DT_DOUBLE && data_size <= kDoubleMaxNums)) {
|
||||
for (int64_t index = 0; index < data_num; ++index) {
|
||||
*(output + index) = t(*(input0 + index), *(input1 + index));
|
||||
}
|
||||
} else {
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (max_core_num > data_num) {
|
||||
max_core_num = data_num;
|
||||
}
|
||||
auto shard_complex = [&](size_t start, size_t end) {
|
||||
for (size_t index = start; index < end; ++index) {
|
||||
*(output + index) = t(*(input0 + index), *(input1 + index));
|
||||
}
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_complex),
|
||||
"complex Compute failed");
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kComplex, ComplexCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright (C) 2020-2021. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
|
||||
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* Apache License for more details at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* @brief
|
||||
*
|
||||
* @version 1.0
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_COMPLEX_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ComplexCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ComplexCpuKernel() = default;
|
||||
~ComplexCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ComplexCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T, typename t>
|
||||
uint32_t ComplexCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2021 Jilin University
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "complex_abs.h"
|
||||
|
||||
#include <unsupported/Eigen/CXX11/Tensor>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const std::uint32_t kComplexAbsInputNum{1u};
|
||||
const std::uint32_t kComplexAbsOutputNum{1u};
|
||||
const char *kComplexAbs{"ComplexAbs"};
|
||||
const std::int64_t kComplexAbsParallelNum{32 * 1024};
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
inline const typename T::value_type ScalarComplexAbs(const T &x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
inline std::uint32_t ParallelForComplexAbs(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size,
|
||||
const std::function<void(std::int64_t, std::int64_t)> &work) {
|
||||
if (total > kComplexAbsParallelNum)
|
||||
return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work);
|
||||
else
|
||||
work(0, total);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeComplexAbsKernel(const CpuKernelContext &ctx) {
|
||||
T *input0{static_cast<T *>(ctx.Input(0)->GetData())};
|
||||
typename T::value_type *output{static_cast<typename T::value_type *>(ctx.Output(0)->GetData())};
|
||||
std::int64_t total{ctx.Input(0)->NumElements()};
|
||||
std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)};
|
||||
std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)};
|
||||
return ParallelForComplexAbs(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) {
|
||||
std::transform(input0 + begin, input0 + end, output + begin, ScalarComplexAbs<T>);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) {
|
||||
std::uint32_t result{ComputeComplexAbsKernel<T>(ctx)};
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("ComplexAbs compute failed.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline std::uint32_t ExtraCheckComplexAbs(const CpuKernelContext &ctx) {
|
||||
if (ctx.Input(0)->GetDataType() == ctx.Output(0)->GetDataType()) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The data type of the input [%s] should not be the same as the output "
|
||||
"[%s].",
|
||||
DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize() * 2) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The data size of the input [%llu] need be as twice as the output "
|
||||
"[%llu].",
|
||||
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
inline std::uint32_t CheckComplexAbs(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) {
|
||||
return NormalCheck(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? KERNEL_STATUS_PARAM_INVALID
|
||||
: ExtraCheckComplexAbs(ctx);
|
||||
}
|
||||
|
||||
inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) {
|
||||
DataType input_type{ctx.Input(0)->GetDataType()};
|
||||
switch (input_type) {
|
||||
case DT_COMPLEX64:
|
||||
return ComputeComplexAbs<std::complex<std::float_t>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return ComputeComplexAbs<std::complex<std::double_t>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::uint32_t ComplexAbsCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
return detail::CheckComplexAbs(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? KERNEL_STATUS_PARAM_INVALID
|
||||
: detail::ComputeComplexAbs(ctx);
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kComplexAbs, ComplexAbsCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Jilin University
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ComplexAbsCpuKernel final : public CpuKernel {
|
||||
virtual std::uint32_t Compute(CpuKernelContext &ctx) override final;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,303 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "concat.h"
|
||||
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
const char *const Concat = "Concat";
|
||||
}
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ConcatCpuKernel::CheckAndInitParams(CpuKernelContext &ctx) {
|
||||
if (ctx.GetAttr("N") == nullptr) {
|
||||
n_ = 1;
|
||||
} else {
|
||||
AttrValue *n_ptr = ctx.GetAttr("N");
|
||||
n_ = n_ptr->GetInt();
|
||||
}
|
||||
// "x" is a list of at least 2 "tensor" objects of the same type
|
||||
KERNEL_CHECK_FALSE((n_ >= 2), KERNEL_STATUS_PARAM_INVALID, "Attr N must >= 2, but got attr N[%lld]", n_);
|
||||
|
||||
uint32_t input_num = ctx.GetInputsSize();
|
||||
|
||||
// input_num is n_(concat tensor num) + 1(concat_dim)
|
||||
KERNEL_CHECK_FALSE((static_cast<int64_t>(input_num) - 1 == n_), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input num must equal attr N[%lld + 1],"
|
||||
"but got input num[%u]",
|
||||
n_, input_num);
|
||||
|
||||
Tensor *concat_dim_ptr = ctx.Input(0);
|
||||
KERNEL_CHECK_NULLPTR(concat_dim_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim failed.");
|
||||
auto concat_dim_shape_ptr = concat_dim_ptr->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(concat_dim_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim shape failed.");
|
||||
int32_t concat_dim_dims = concat_dim_shape_ptr->GetDims();
|
||||
KERNEL_CHECK_FALSE((concat_dim_dims == 0) || ((concat_dim_dims == 1) && (concat_dim_shape_ptr->NumElements() == 1)),
|
||||
KERNEL_STATUS_PARAM_INVALID, "Input concat_dim should be a scalar integer, but got rank[%d].",
|
||||
concat_dim_dims);
|
||||
int32_t concat_dim = 0;
|
||||
DataType concat_dim_data_type = concat_dim_ptr->GetDataType();
|
||||
KERNEL_CHECK_FALSE((concat_dim_data_type == DT_INT32 || concat_dim_data_type == DT_INT64),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input concat_dim data type must DT_INT32 or DT_INT64,"
|
||||
"but got data type[%d].",
|
||||
DTypeStr(concat_dim_data_type).c_str());
|
||||
auto concat_dim_data_ptr = concat_dim_ptr->GetData();
|
||||
KERNEL_CHECK_NULLPTR(concat_dim_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim data failed.");
|
||||
if (concat_dim_data_type == DT_INT32) {
|
||||
concat_dim = static_cast<int64_t>(*reinterpret_cast<int32_t *>(concat_dim_data_ptr));
|
||||
} else {
|
||||
concat_dim = *reinterpret_cast<int64_t *>(concat_dim_data_ptr);
|
||||
}
|
||||
|
||||
Tensor *input0_ptr = ctx.Input(1);
|
||||
auto input0_type_ptr = input0_ptr->GetDataType();
|
||||
for (int64_t i = 1; i < n_; i++) {
|
||||
Tensor *inputi_ptr = ctx.Input(i);
|
||||
auto inputi_shape_ptr = inputi_ptr->GetTensorShape();
|
||||
auto inputi_type_ptr = inputi_ptr->GetDataType();
|
||||
KERNEL_CHECK_NULLPTR(inputi_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input xi failed.");
|
||||
KERNEL_CHECK_NULLPTR(inputi_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input xi shape failed.");
|
||||
KERNEL_CHECK_FALSE((input0_type_ptr == inputi_type_ptr), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input tensor should have same type"
|
||||
"but got %d and %d.",
|
||||
DTypeStr(input0_type_ptr).c_str(), DTypeStr(inputi_type_ptr).c_str());
|
||||
}
|
||||
auto input0_shape_ptr = input0_ptr->GetTensorShape();
|
||||
input_dims_ = input0_shape_ptr->GetDims();
|
||||
data_type_ = input0_ptr->GetDataType();
|
||||
axis_ = concat_dim < 0 ? concat_dim + input_dims_ : concat_dim;
|
||||
KERNEL_CHECK_FALSE((0 <= axis_ && axis_ < input_dims_), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input concat_dim need in the "
|
||||
"range[%d, %d), but got %lld.",
|
||||
-input_dims_, input_dims_, concat_dim);
|
||||
inputs_flat_dim0_ = 1;
|
||||
for (uint32_t d = 0; d < axis_; ++d) {
|
||||
inputs_flat_dim0_ *= input0_shape_ptr->GetDimSize(d);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ConcatCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_CHECK_FALSE((CheckAndInitParams(ctx) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
|
||||
"CheckAndInitParams failed.");
|
||||
switch (data_type_) {
|
||||
case DT_FLOAT16:
|
||||
return DoCompute<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DoCompute<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DoCompute<double>(ctx);
|
||||
case DT_INT8:
|
||||
return DoCompute<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DoCompute<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DoCompute<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DoCompute<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DoCompute<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DoCompute<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DoCompute<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DoCompute<uint64_t>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DoCompute<complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DoCompute<complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("unsupported datatype[%d]", DTypeStr(data_type_).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConcatCpuKernel::PrepareInput(CpuKernelContext &ctx,
|
||||
std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs) {
|
||||
inputs.reserve(n_);
|
||||
output_concat_dim_ = 0;
|
||||
auto input0_shape_ptr = ctx.Input(1)->GetTensorShape();
|
||||
for (uint32_t i = 0; i < n_; ++i) {
|
||||
Tensor *input_i_ptr = ctx.Input(i + 1);
|
||||
int64_t input_i_num = input_i_ptr->NumElements();
|
||||
if (input_i_num == 0) {
|
||||
continue;
|
||||
}
|
||||
auto input_i_shape_ptr = input_i_ptr->GetTensorShape();
|
||||
int32_t input_i_dims = input_i_shape_ptr->GetDims();
|
||||
KERNEL_CHECK_FALSE((input_i_dims == input_dims_), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Ranks of inputs should match: shape[0]=%d vs. shape[%u]=%d", input_dims_, i, input_i_dims);
|
||||
for (int32_t j = 0; j < input_dims_; ++j) {
|
||||
int64_t dim_ij = input_i_shape_ptr->GetDimSize(j);
|
||||
if (j == axis_) {
|
||||
output_concat_dim_ += input_i_dims > 0 ? dim_ij : 1;
|
||||
continue;
|
||||
}
|
||||
int64_t dim_0j = input0_shape_ptr->GetDimSize(j);
|
||||
KERNEL_CHECK_FALSE((dim_0j == dim_ij), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Dimensions of inputs should match: shape[0][%d]=%lld vs."
|
||||
"shape[%u][%d]=%lld",
|
||||
j, dim_0j, i, j, dim_ij);
|
||||
}
|
||||
|
||||
int64_t inputs_flat_dim1 = input_i_num / inputs_flat_dim0_;
|
||||
auto input_i_data_ptr = input_i_ptr->GetData();
|
||||
KERNEL_CHECK_NULLPTR(input_i_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x%u data failed.", i);
|
||||
auto input_i = std::make_shared<typename TTypes<T>::ConstMatrix>(reinterpret_cast<T *>(input_i_data_ptr),
|
||||
inputs_flat_dim0_, inputs_flat_dim1);
|
||||
KERNEL_CHECK_NULLPTR(input_i, KERNEL_STATUS_PARAM_INVALID, "Create input x%u failed!", i);
|
||||
inputs.emplace_back(std::move(input_i));
|
||||
}
|
||||
|
||||
if (input_dims_ == 0) {
|
||||
output_concat_dim_ = n_;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConcatCpuKernel::PrepareOutput(CpuKernelContext &ctx, std::shared_ptr<typename TTypes<T>::Matrix> &output) {
|
||||
Tensor *output_ptr = ctx.Output(0);
|
||||
KERNEL_CHECK_NULLPTR(output_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output failed.");
|
||||
auto output_data_ptr = output_ptr->GetData();
|
||||
KERNEL_CHECK_NULLPTR(output_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.");
|
||||
int64_t output_num = output_ptr->NumElements();
|
||||
int64_t output_dim1 = output_num / inputs_flat_dim0_;
|
||||
output = std::make_shared<typename TTypes<T>::Matrix>(reinterpret_cast<T *>(output_data_ptr), inputs_flat_dim0_,
|
||||
output_dim1);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Create output matrix failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConcatCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> inputs;
|
||||
KERNEL_CHECK_FALSE((PrepareInput<T>(ctx, inputs) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
|
||||
"PrepareInput failed.");
|
||||
std::shared_ptr<typename TTypes<T>::Matrix> output = nullptr;
|
||||
KERNEL_CHECK_FALSE((PrepareOutput<T>(ctx, output) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID,
|
||||
"PrepareOutput failed.");
|
||||
if (inputs.size() > 0) {
|
||||
return ConcatCompute<T>(ctx, inputs, output);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConcatCpuKernel::ConcatCompute(CpuKernelContext &ctx,
|
||||
const std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs,
|
||||
std::shared_ptr<typename TTypes<T>::Matrix> &output) {
|
||||
size_t num_inputs = inputs.size();
|
||||
std::vector<ptrdiff_t> sizes;
|
||||
sizes.reserve(num_inputs);
|
||||
int64_t row_size = 0;
|
||||
for (const auto &input : inputs) {
|
||||
sizes.push_back(input->dimension(1));
|
||||
row_size += sizes.back();
|
||||
}
|
||||
uint32_t ret = KERNEL_STATUS_OK;
|
||||
auto work = [&row_size, &sizes, &inputs, &output, &num_inputs, &ret](int64_t start, int64_t end) {
|
||||
if (row_size == 0) {
|
||||
ret = KERNEL_STATUS_PARAM_INVALID;
|
||||
return;
|
||||
}
|
||||
int64_t skipped_rows = start / row_size;
|
||||
T *out = output->data() + skipped_rows * row_size;
|
||||
T *out_start = output->data() + start;
|
||||
T *out_end = output->data() + end;
|
||||
|
||||
// Handle partial row at start
|
||||
if (out < out_start) {
|
||||
for (size_t j = 0; j < num_inputs; ++j) {
|
||||
ptrdiff_t size = sizes[j];
|
||||
ptrdiff_t offset = out_start - out;
|
||||
if (size <= offset) {
|
||||
out += size;
|
||||
continue;
|
||||
}
|
||||
const T *inp = &(*inputs[j])(skipped_rows, 0);
|
||||
if (offset > 0) {
|
||||
out += offset;
|
||||
inp += offset;
|
||||
size -= offset;
|
||||
}
|
||||
size = std::min(size, out_end - out);
|
||||
KERNEL_CHECK_FALSE_EXEC((size > 0), break)
|
||||
size_t copy_size = size * sizeof(T);
|
||||
error_t ret = memcpy_s(out, copy_size, inp, copy_size);
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("Memcpy failed.");
|
||||
ret = KERNEL_STATUS_INNER_ERROR;
|
||||
return;
|
||||
}
|
||||
out += size;
|
||||
}
|
||||
++skipped_rows;
|
||||
}
|
||||
if (out < out_start || out > out_end) {
|
||||
KERNEL_LOG_ERROR("Out[%llx] not in range[%llx, %llx)", out, out_start, out_end);
|
||||
ret = KERNEL_STATUS_INNER_ERROR;
|
||||
return;
|
||||
}
|
||||
// Copy remaining data.
|
||||
std::vector<const T *> inp;
|
||||
inp.reserve(num_inputs);
|
||||
for (const auto &input : inputs) {
|
||||
inp.push_back(&(*input)(skipped_rows, 0));
|
||||
}
|
||||
const int64_t dim0 = output->dimension(0);
|
||||
for (int64_t i = skipped_rows; i < dim0; ++i) {
|
||||
for (int64_t j = 0; j < static_cast<int64_t>(num_inputs); ++j) {
|
||||
ptrdiff_t size = std::min(sizes[j], out_end - out);
|
||||
size_t copy_size = size * sizeof(T);
|
||||
auto ret = memcpy_s(out, copy_size, inp[j], copy_size);
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("Memcpy size[%zu] from inp[%llx] to out[%llx] failed.", copy_size, inp[j], out);
|
||||
ret = KERNEL_STATUS_INNER_ERROR;
|
||||
return;
|
||||
}
|
||||
out += size;
|
||||
inp[j] += size;
|
||||
KERNEL_CHECK_FALSE_EXEC((out != out_end), return );
|
||||
}
|
||||
}
|
||||
};
|
||||
const int64_t kParallelDataNumSameShapeBig = 255 * 1024;
|
||||
int64_t data_num = output->size();
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
|
||||
if (data_num >= kParallelDataNumSameShapeBig) {
|
||||
max_core_num = std::min(max_core_num, 6U);
|
||||
} else {
|
||||
max_core_num = std::min(max_core_num, 1U);
|
||||
}
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0.");
|
||||
}
|
||||
CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, work);
|
||||
KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "ConcatCpuKernel failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(Concat, ConcatCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CONCAT_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CONCAT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "kernel_log.h"
|
||||
#include "securec.h"
|
||||
#include "status.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace aicpu {
|
||||
const uint32_t NumIndices = 2;
|
||||
template <typename T>
|
||||
struct TTypes {
|
||||
// Rank-2 tensor (matrix) of scalar type T.
|
||||
using Matrix = Eigen::TensorMap<Eigen::Tensor<T, NumIndices, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>;
|
||||
|
||||
using ConstMatrix =
|
||||
Eigen::TensorMap<Eigen::Tensor<const T, NumIndices, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>;
|
||||
};
|
||||
|
||||
class ConcatCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ConcatCpuKernel()
|
||||
: data_type_(DT_DOUBLE), input_dims_(0), n_(0), output_concat_dim_(0), axis_(0), inputs_flat_dim0_(0) {}
|
||||
|
||||
~ConcatCpuKernel() = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t CheckAndInitParams(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t PrepareInput(CpuKernelContext &ctx, std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs);
|
||||
|
||||
template <typename T>
|
||||
uint32_t PrepareOutput(CpuKernelContext &ctx, std::shared_ptr<typename TTypes<T>::Matrix> &output);
|
||||
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConcatCompute(CpuKernelContext &ctx,
|
||||
const std::vector<std::shared_ptr<typename TTypes<T>::ConstMatrix>> &inputs,
|
||||
std::shared_ptr<typename TTypes<T>::Matrix> &output);
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
int32_t input_dims_;
|
||||
int64_t n_;
|
||||
int64_t output_concat_dim_;
|
||||
int64_t axis_;
|
||||
int64_t inputs_flat_dim0_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2021 Jilin University
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cos.h"
|
||||
|
||||
#include <unsupported/Eigen/CXX11/Tensor>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const std::uint32_t kCosInputNum{1u};
|
||||
const std::uint32_t kCosOutputNum{1u};
|
||||
const char *kCos{"Cos"};
|
||||
const std::int64_t kCosParallelNum{64 * 1024};
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
inline T ScalarCos(T x) {
|
||||
return std::cos(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Eigen::half ScalarCos(Eigen::half x) {
|
||||
const Eigen::half val{static_cast<Eigen::half>(std::cos(static_cast<std::float_t>(x)))};
|
||||
return Eigen::half_impl::isnan(val) ? Eigen::half{0.0f} : val;
|
||||
}
|
||||
|
||||
inline std::uint32_t ParallelForCos(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size,
|
||||
const std::function<void(std::int64_t, std::int64_t)> &work) {
|
||||
if (total > kCosParallelNum)
|
||||
return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work);
|
||||
else
|
||||
work(0, total);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeCosKernel(const CpuKernelContext &ctx) {
|
||||
T *input0{static_cast<T *>(ctx.Input(0)->GetData())};
|
||||
T *output{static_cast<T *>(ctx.Output(0)->GetData())};
|
||||
std::int64_t total{ctx.Input(0)->NumElements()};
|
||||
std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)};
|
||||
std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)};
|
||||
return ParallelForCos(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) {
|
||||
std::transform(input0 + begin, input0 + end, output + begin, ScalarCos<T>);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) {
|
||||
std::uint32_t result{ComputeCosKernel<T>(ctx)};
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Cos compute failed.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline std::uint32_t ExtraCheckCos(const CpuKernelContext &ctx) {
|
||||
if (ctx.Input(0)->GetDataType() != ctx.Output(0)->GetDataType()) {
|
||||
KERNEL_LOG_ERROR("The data type of the input [%s] need be the same as the output [%s].",
|
||||
DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize()) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The data size of the input [%llu] need be the same as the output "
|
||||
"[%llu].",
|
||||
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
inline std::uint32_t CheckCos(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) {
|
||||
return NormalCheck(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : ExtraCheckCos(ctx);
|
||||
}
|
||||
|
||||
inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) {
|
||||
DataType input_type{ctx.Input(0)->GetDataType()};
|
||||
switch (input_type) {
|
||||
case DT_FLOAT16:
|
||||
return ComputeCos<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return ComputeCos<std::float_t>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return ComputeCos<std::double_t>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return ComputeCos<std::complex<std::float_t>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return ComputeCos<std::complex<std::double_t>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::uint32_t CosCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
return detail::CheckCos(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : detail::ComputeCos(ctx);
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kCos, CosCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Jilin University
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_COS_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_COS_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CosCpuKernel final : public CpuKernel {
|
||||
virtual std::uint32_t Compute(CpuKernelContext &ctx) override final;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,272 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "count_nonzero.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kCountNonZeroInputNum = 1;
|
||||
const uint32_t kCountNonZeroOutputNum = 1;
|
||||
const int64_t kParallelNum = 2 * 1024;
|
||||
const char *kCountNonZero = "CountNonZero";
|
||||
|
||||
// The following code is used to handle the general case.
|
||||
// Params to use in ParallelIterator construction.
|
||||
std::vector<int64_t> cnz_dims;
|
||||
std::vector<int64_t> cnz_transposed_shape;
|
||||
int64_t cnz_stride;
|
||||
|
||||
// Class def of ParallelIterator.
|
||||
class ParallelIterator {
|
||||
public:
|
||||
ParallelIterator(std::vector<int64_t> transposed_shape, std::vector<int64_t> dims,
|
||||
const std::vector<int64_t> &input_shape);
|
||||
~ParallelIterator() = default;
|
||||
void Next();
|
||||
void Set(int64_t pos);
|
||||
inline int64_t Get() const { return _pos; };
|
||||
|
||||
private:
|
||||
int64_t _dimension{0};
|
||||
std::vector<int64_t> _coord;
|
||||
std::vector<int64_t> _shape;
|
||||
std::vector<int64_t> _strides;
|
||||
std::vector<int64_t> _back_strides;
|
||||
std::vector<int64_t> _dims;
|
||||
int64_t _pos{0};
|
||||
};
|
||||
|
||||
ParallelIterator::ParallelIterator(std::vector<int64_t> transposed_shape, std::vector<int64_t> dims,
|
||||
const std::vector<int64_t> &input_shape)
|
||||
: _dimension(transposed_shape.size()),
|
||||
_coord(transposed_shape.size(), 0),
|
||||
_shape(transposed_shape),
|
||||
_strides(transposed_shape.size(), 1),
|
||||
_back_strides(transposed_shape.size(), 1),
|
||||
_dims(dims),
|
||||
_pos(0) {
|
||||
std::vector<int64_t> strides(_dimension, 1);
|
||||
for (int64_t i = _dimension - 2; i >= 0; --i) {
|
||||
strides[i] = strides[i + 1] * input_shape[i + 1];
|
||||
}
|
||||
for (int64_t i = _dimension - 1; i >= 0; --i) {
|
||||
_strides[i] = strides[_dims[i]];
|
||||
_back_strides[i] = (_shape[i] - 1) * _strides[i];
|
||||
}
|
||||
}
|
||||
void ParallelIterator::Set(int64_t pos) {
|
||||
for (int64_t i = _dimension - 1; i >= 0 && pos != 0; --i) {
|
||||
_coord[i] = pos % _shape[i];
|
||||
_pos += _coord[i] * _strides[i];
|
||||
pos /= _shape[i];
|
||||
}
|
||||
}
|
||||
void ParallelIterator::Next() {
|
||||
for (int64_t i = _dimension - 1; i >= 0; --i) {
|
||||
if (_coord[i] + 1 == _shape[i]) {
|
||||
_coord[i] = 0;
|
||||
_pos -= _back_strides[i];
|
||||
} else {
|
||||
_coord[i]++;
|
||||
_pos += _strides[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The two structs is used for tag dispatch in IsNonZero.
|
||||
template <class T>
|
||||
struct is_complex_t : std::false_type {};
|
||||
template <class T>
|
||||
struct is_complex_t<std::complex<T>> : std::true_type {};
|
||||
|
||||
template <class T>
|
||||
int64_t IsNonZero(T val, std::true_type) {
|
||||
return val.real() != 0 || val.imag() != 0 ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
|
||||
}
|
||||
template <class T>
|
||||
int64_t IsNonZero(T val, std::false_type) {
|
||||
return val != static_cast<T>(0) ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
template <class T>
|
||||
uint32_t CountNonZeroComputeImpl(CpuKernelContext &ctx) {
|
||||
Tensor *x_tensor = ctx.Input(kFirstInputIndex);
|
||||
Tensor *y_tensor = ctx.Output(kFirstOutputIndex);
|
||||
const T *x_ptr = reinterpret_cast<const T *>(x_tensor->GetData());
|
||||
int64_t *y_ptr = reinterpret_cast<int64_t *>(y_tensor->GetData());
|
||||
int64_t data_num = y_tensor->NumElements();
|
||||
int64_t input_data_num = x_tensor->NumElements();
|
||||
std::vector<int64_t> input_shape = x_tensor->GetTensorShape()->GetDimSizes();
|
||||
|
||||
// For scalar_reduction, start=0, end=input_data_num, cannot be parallelized.
|
||||
auto count_nonzero_scalar_shard = [&](int64_t start, int64_t end) {
|
||||
y_ptr[0] = 0;
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
y_ptr[0] += IsNonZero<T>(x_ptr[i], is_complex_t<T>{});
|
||||
}
|
||||
};
|
||||
|
||||
// For general case. Can be parallelized but performance is not good.
|
||||
auto count_nonzero_shard = [&](int64_t start, int64_t end) {
|
||||
ParallelIterator iter(cnz_transposed_shape, cnz_dims, input_shape);
|
||||
iter.Set(start * cnz_stride);
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
int64_t reduce_initial = static_cast<int64_t>(0);
|
||||
for (int64_t j = 0; j < cnz_stride; ++j) {
|
||||
reduce_initial += IsNonZero<T>(x_ptr[iter.Get()], is_complex_t<T>{});
|
||||
iter.Next();
|
||||
}
|
||||
y_ptr[i] = reduce_initial;
|
||||
}
|
||||
};
|
||||
|
||||
if (data_num == 1) {
|
||||
count_nonzero_scalar_shard(0, input_data_num);
|
||||
} else if (data_num > kParallelNum) {
|
||||
CpuKernelUtils::ParallelFor(ctx, data_num, 1, count_nonzero_shard);
|
||||
} else {
|
||||
count_nonzero_shard(0, data_num);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CountNonZeroDimsCheckAndParse(CpuKernelContext &ctx) {
|
||||
std::vector<int64_t> input_shape = ctx.Input(kFirstInputIndex)->GetTensorShape()->GetDimSizes();
|
||||
int64_t input_rank = input_shape.size();
|
||||
std::vector<int64_t> dims{};
|
||||
|
||||
auto dims_attr = ctx.GetAttr("dims");
|
||||
if (dims_attr != nullptr) {
|
||||
dims = dims_attr->GetListInt();
|
||||
}
|
||||
if (dims.size() == 0) {
|
||||
for (int64_t i = 0; i < input_rank; ++i) {
|
||||
dims.push_back(i);
|
||||
}
|
||||
}
|
||||
// Check dims in [-x_rank, x_rank)
|
||||
for (auto &dim : dims) {
|
||||
if (dim < 0) {
|
||||
dim += input_rank;
|
||||
}
|
||||
KERNEL_CHECK_FALSE(dim < input_rank && dim >= 0, KERNEL_STATUS_PARAM_INVALID,
|
||||
"[CountNonZero] dims must be in [-x_rank, x_rank).");
|
||||
}
|
||||
std::sort(dims.begin(), dims.end());
|
||||
dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
|
||||
int64_t stride_ = static_cast<int64_t>(1);
|
||||
std::vector<int64_t> transposed_shape(input_rank);
|
||||
// axes is the transpose of indices.
|
||||
// For example, if input_rank = 5, dims = [1, 3],
|
||||
// then axes =[0, 2, 4] + [1, 3].
|
||||
// Initial value if axes is [?, ?, ?, ?, ?]
|
||||
std::vector<int64_t> axes_(input_rank);
|
||||
int64_t j = static_cast<int64_t>(0), k = static_cast<int64_t>(0);
|
||||
// Put dim indices to keep to front of axes and calculate stride.
|
||||
// After this operation, axes becomes [0, 2, 4] + [?, ?],
|
||||
// and stride becomes 1 * 2 * 4
|
||||
for (int64_t i = 0; i < input_rank; i++) {
|
||||
if (j == static_cast<int64_t>(dims.size()) || i != dims[j]) {
|
||||
axes_[k] = i;
|
||||
++k;
|
||||
} else {
|
||||
stride_ *= input_shape[i];
|
||||
++j;
|
||||
}
|
||||
}
|
||||
// Put dim indices to reduce to back of axes.
|
||||
// After this operation, axes becomes [0, 2, 4] + [1, 3]
|
||||
for (auto &dim : dims) {
|
||||
axes_[k] = dim;
|
||||
++k;
|
||||
}
|
||||
// Calculate transposed_shape using axes.
|
||||
// For example, if input_shape = (3, 4, 5, 6, 7), axes = [0, 2, 4, 1, 3],
|
||||
// then transposed_shape = (3, 5, 7) + (4, 6)
|
||||
std::vector<int64_t> transposed_shape_(input_rank);
|
||||
for (int64_t i = 0; i < input_rank; ++i) {
|
||||
transposed_shape_[i] = input_shape[axes_[i]];
|
||||
}
|
||||
// Assign values.
|
||||
cnz_stride = stride_, cnz_transposed_shape = transposed_shape_, cnz_dims = axes_;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CountNonZeroCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCountNonZeroInputNum, kCountNonZeroOutputNum),
|
||||
"[%s] check input and output failed.", kCountNonZero);
|
||||
KERNEL_HANDLE_ERROR(CountNonZeroDimsCheckAndParse(ctx), "[%s] check & parse dims failed.", kCountNonZero);
|
||||
auto y_data_type = ctx.Output(kFirstOutputIndex)->GetDataType();
|
||||
KERNEL_CHECK_FALSE(y_data_type == DT_INT64, KERNEL_STATUS_PARAM_INVALID,
|
||||
"[CountNonZero] Data type of output not supported, which is [%s].", DTypeStr(y_data_type).c_str());
|
||||
auto x_data_type = ctx.Input(kFirstInputIndex)->GetDataType();
|
||||
switch (x_data_type) {
|
||||
case DT_INT8:
|
||||
return CountNonZeroComputeImpl<int8_t>(ctx);
|
||||
break;
|
||||
case DT_INT16:
|
||||
return CountNonZeroComputeImpl<int16_t>(ctx);
|
||||
break;
|
||||
case DT_INT32:
|
||||
return CountNonZeroComputeImpl<int32_t>(ctx);
|
||||
break;
|
||||
case DT_INT64:
|
||||
return CountNonZeroComputeImpl<int64_t>(ctx);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
return CountNonZeroComputeImpl<uint8_t>(ctx);
|
||||
break;
|
||||
case DT_UINT16:
|
||||
return CountNonZeroComputeImpl<uint16_t>(ctx);
|
||||
break;
|
||||
case DT_UINT32:
|
||||
return CountNonZeroComputeImpl<uint32_t>(ctx);
|
||||
break;
|
||||
case DT_UINT64:
|
||||
return CountNonZeroComputeImpl<uint64_t>(ctx);
|
||||
break;
|
||||
case DT_FLOAT16:
|
||||
return CountNonZeroComputeImpl<Eigen::half>(ctx);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
return CountNonZeroComputeImpl<float>(ctx);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
return CountNonZeroComputeImpl<double>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
return CountNonZeroComputeImpl<std::complex<float>>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
return CountNonZeroComputeImpl<std::complex<double>>(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[CountNonZero] kernel data type [%s] not support.", DTypeStr(x_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kCountNonZero, CountNonZeroCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H
|
||||
#define AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CountNonZeroCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~CountNonZeroCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,124 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "csr_sparse_matrix_to_dense.h"
|
||||
#include <securec.h>
|
||||
#include <complex>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "utils/allocator_utils.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 5;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *CSRSparseMatrixToDense = "CSRSparseMatrixToDense";
|
||||
|
||||
#define SWITCH_CASE(_IDX_T, _VALUE_T, VALUE_T, FLAG, CTX) \
|
||||
case _VALUE_T: \
|
||||
switch (_IDX_T) { \
|
||||
case DT_INT32: \
|
||||
(FLAG) = DoCompute<int32_t, VALUE_T>(CTX); \
|
||||
break; \
|
||||
case DT_INT64: \
|
||||
(FLAG) = DoCompute<int64_t, VALUE_T>(CTX); \
|
||||
break; \
|
||||
default: \
|
||||
KERNEL_LOG_ERROR("CSRSparseMatrixToDense index type [%s] not support.", DTypeStr(_IDX_T).c_str()); \
|
||||
return KERNEL_STATUS_PARAM_INVALID; \
|
||||
} \
|
||||
break;
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CSRSparseMatrixToDenseCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToDense normal check failed.");
|
||||
DataType indice_type = ctx.Input(0)->GetDataType();
|
||||
DataType value_type = ctx.Input(4)->GetDataType();
|
||||
uint32_t status = 0;
|
||||
switch (value_type) {
|
||||
SWITCH_CASE(indice_type, DT_FLOAT, float_t, status, ctx)
|
||||
SWITCH_CASE(indice_type, DT_DOUBLE, double_t, status, ctx)
|
||||
SWITCH_CASE(indice_type, DT_COMPLEX64, complex<float_t>, status, ctx)
|
||||
SWITCH_CASE(indice_type, DT_COMPLEX128, complex<double_t>, status, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("CSRSparseMatrixToDense values type [%s] not support.", DTypeStr(value_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(status, "CSRSparseMatrixToDense compute failed!");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename indiceT, typename valueT>
|
||||
uint32_t CSRSparseMatrixToDenseCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
indiceT batch_size = ctx.Input(1)->NumElements() - 1;
|
||||
auto rank = ctx.Input(0)->NumElements();
|
||||
int shift = (rank == 2) ? 0 : 1;
|
||||
indiceT num_rows = *(static_cast<indiceT *>(ctx.Input(0)->GetData()) + shift);
|
||||
indiceT num_cols = *(static_cast<indiceT *>(ctx.Input(0)->GetData()) + shift + 1);
|
||||
indiceT *batch_ptrs = static_cast<indiceT *>(ctx.Input(1)->GetData());
|
||||
indiceT *row_ptrs = static_cast<indiceT *>(ctx.Input(2)->GetData());
|
||||
indiceT *col_ind = static_cast<indiceT *>(ctx.Input(3)->GetData());
|
||||
valueT *values = static_cast<valueT *>(ctx.Input(4)->GetData());
|
||||
auto output = ctx.Output(0);
|
||||
auto output_shape = output->GetTensorShape();
|
||||
if (rank == 2) {
|
||||
output_shape->SetDimSizes({num_rows, num_cols});
|
||||
} else {
|
||||
output_shape->SetDimSizes({batch_size, num_rows, num_cols});
|
||||
}
|
||||
output->SetTensorShape(output_shape.get());
|
||||
valueT *y_data = static_cast<valueT *>(ctx.Output(0)->GetData());
|
||||
// use multi-thread
|
||||
uint32_t min_core = 1;
|
||||
uint32_t max_core = std::max(min_core, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
max_core = std::min(max_core, (uint32_t)batch_size);
|
||||
if (max_core == 0) {
|
||||
KERNEL_LOG_ERROR("Max core num cannot be zero.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto shard = [&](int64_t start, int64_t end) {
|
||||
for (int64_t batch_idx = start; batch_idx < end; batch_idx++) {
|
||||
const int64_t dense_offset = batch_idx * num_rows * num_cols;
|
||||
for (int64_t i = 0; i < num_rows * num_cols; ++i) {
|
||||
y_data[dense_offset + i] = 0;
|
||||
}
|
||||
const int64_t csr_batch_offset = batch_ptrs[batch_idx];
|
||||
for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
|
||||
const int64_t row_offset = batch_idx * (num_rows + 1) + row_idx;
|
||||
const int64_t col_begin = row_ptrs[row_offset];
|
||||
const int64_t col_end = row_ptrs[row_offset + 1];
|
||||
for (int64_t i = col_begin; i < col_end; ++i) {
|
||||
const int64_t col_idx = col_ind[csr_batch_offset + i];
|
||||
y_data[dense_offset + (row_idx * num_cols) + col_idx] = values[csr_batch_offset + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core, shard),
|
||||
"CSRSparseMatrixToDense Compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
// register the opetaor
|
||||
REGISTER_CPU_KERNEL(CSRSparseMatrixToDense, CSRSparseMatrixToDenseCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class CSRSparseMatrixToDenseCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~CSRSparseMatrixToDenseCpuKernel() = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename indiceT, typename valueT>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,229 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "csr_sparse_matrix_to_sparse_tensor.h"
|
||||
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 5;
|
||||
const uint32_t kOutputNum = 3;
|
||||
const char *CSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
|
||||
// when input data size is more than kParallelDataNum, use Parallel func
|
||||
const int64_t kParallelDataNum = 4;
|
||||
const int64_t kParallelDataNumMid = 32;
|
||||
const int DIM2 = 2;
|
||||
const int DIM3 = 3;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CSRSparseMatrixToSparseTensorCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToSparseTensor normal check failed.");
|
||||
Tensor *x_dense_shape = ctx.Input(0);
|
||||
Tensor *x_batch_pointers = ctx.Input(1);
|
||||
Tensor *x_row_pointers = ctx.Input(2);
|
||||
Tensor *x_col_indices = ctx.Input(3);
|
||||
Tensor *x_values = ctx.Input(4);
|
||||
|
||||
const int rank = x_dense_shape->NumElements();
|
||||
if (rank != DIM2 && rank != DIM3) {
|
||||
KERNEL_LOG_ERROR("CSR SparseMatrix must have rank 2 or 3.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto x_row_pointers_shape = x_row_pointers->GetTensorShape();
|
||||
auto x_col_indices_shape = x_col_indices->GetTensorShape();
|
||||
auto x_values_shape = x_values->GetTensorShape();
|
||||
if (x_col_indices_shape->NumElements() != x_values_shape->NumElements()) {
|
||||
KERNEL_LOG_ERROR("Tensor x_col_indices&x_values's ranks mismatch.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto x_dense_shape_data_type = x_dense_shape->GetDataType();
|
||||
auto x_batch_pointers_data_type = x_batch_pointers->GetDataType();
|
||||
auto x_row_pointers_data_type = x_row_pointers->GetDataType();
|
||||
auto x_col_indices_data_type = x_col_indices->GetDataType();
|
||||
if (x_col_indices_data_type != DT_INT32 && x_col_indices_data_type != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type [%s] not support.",
|
||||
DTypeStr(x_col_indices_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (x_dense_shape_data_type != x_col_indices_data_type || x_batch_pointers_data_type != x_col_indices_data_type ||
|
||||
x_row_pointers_data_type != x_col_indices_data_type) {
|
||||
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type mismatch.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto x_values_data_type = x_values->GetDataType();
|
||||
|
||||
uint32_t status;
|
||||
switch (x_col_indices_data_type) {
|
||||
case DT_INT32:
|
||||
switch (x_values_data_type) {
|
||||
case DT_FLOAT:
|
||||
status = ComputeKernel<int32_t, float>(ctx);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
status = ComputeKernel<int32_t, double>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
status = ComputeKernel<int32_t, std::complex<float> >(ctx);
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
status = ComputeKernel<int32_t, std::complex<double> >(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR(
|
||||
"CSRSparseMatrixToSparseTensor kernel data type [%s] not "
|
||||
"support.",
|
||||
DTypeStr(x_values_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
break;
|
||||
case DT_INT64:
|
||||
switch (x_values_data_type) {
|
||||
case DT_FLOAT:
|
||||
status = ComputeKernel<int64_t, float>(ctx);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
status = ComputeKernel<int64_t, double>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
status = ComputeKernel<int64_t, std::complex<float> >(ctx);
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
status = ComputeKernel<int64_t, std::complex<double> >(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR(
|
||||
"CSRSparseMatrixToSparseTensor kernel data type [%s] not "
|
||||
"support.",
|
||||
DTypeStr(x_values_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR("data type of indices is not int32 or int64");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (status != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel compute failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTensorCpuKernel);
|
||||
|
||||
template <typename indicesT, typename dataT>
|
||||
uint32_t CSRSparseMatrixToSparseTensorCpuKernel::ComputeKernel(CpuKernelContext &ctx) {
|
||||
auto x_dense_shape = ctx.Input(0);
|
||||
auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
|
||||
auto dense_shape_ptr = static_cast<indicesT *>(ctx.Output(2)->GetData());
|
||||
auto values_ptr = static_cast<dataT *>(ctx.Output(1)->GetData());
|
||||
auto x_values = ctx.Input(4);
|
||||
auto x_values_ptr = static_cast<dataT *>(x_values->GetData());
|
||||
|
||||
// Copy the SparseTensor's dense_shape and values from the CSRSparseMatrix.
|
||||
for (int64_t i = 0; i < x_dense_shape->GetTensorShape()->NumElements(); i++) {
|
||||
dense_shape_ptr[i] = x_dense_shape_ptr[i];
|
||||
}
|
||||
for (int64_t i = 0; i < x_values->GetTensorShape()->NumElements(); i++) {
|
||||
values_ptr[i] = x_values_ptr[i];
|
||||
}
|
||||
|
||||
const uint32_t batch_size = ctx.Input(1)->NumElements() - 1;
|
||||
if (batch_size >= kParallelDataNum) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
|
||||
if (batch_size <= kParallelDataNumMid) {
|
||||
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
|
||||
}
|
||||
|
||||
if (max_core_num > batch_size) {
|
||||
max_core_num = batch_size;
|
||||
}
|
||||
|
||||
auto sharder = [&](int64_t batch_begin, int64_t batch_end) {
|
||||
SpecialCompute<indicesT>(batch_begin, batch_end, ctx);
|
||||
};
|
||||
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0.");
|
||||
}
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core_num, sharder),
|
||||
"CSRSparseMatrixToSparseTensor Compute failed.");
|
||||
} else {
|
||||
SpecialCompute<indicesT>(0, batch_size, ctx);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename indicesT>
|
||||
void CSRSparseMatrixToSparseTensorCpuKernel::SpecialCompute(int64_t batch_begin, int64_t batch_end,
|
||||
CpuKernelContext &ctx) {
|
||||
auto x_dense_shape = ctx.Input(0);
|
||||
const int rank = x_dense_shape->NumElements();
|
||||
auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
|
||||
const int64_t num_rows = x_dense_shape_ptr[(rank == DIM2) ? 0 : 1];
|
||||
auto x_batch_pointers_ptr = static_cast<indicesT *>(ctx.Input(1)->GetData());
|
||||
auto x_row_pointers_ptr = static_cast<indicesT *>(ctx.Input(2)->GetData());
|
||||
auto x_col_indices_ptr = static_cast<indicesT *>(ctx.Input(3)->GetData());
|
||||
|
||||
for (int64_t batch_idx = batch_begin; batch_idx < batch_end; ++batch_idx) {
|
||||
const int64_t batch_offset = x_batch_pointers_ptr[batch_idx];
|
||||
|
||||
for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
|
||||
int64_t row_offset = batch_idx * (num_rows + 1) + row_idx;
|
||||
|
||||
// The column indices of the current row lie in the range:
|
||||
// [x_row_pointers_ptr[row_offset], x_row_pointer_ptr[row_offset + 1]]
|
||||
const int64_t col_begin = x_row_pointers_ptr[row_offset];
|
||||
const int64_t col_end = x_row_pointers_ptr[row_offset + 1];
|
||||
for (int64_t i = col_begin; i < col_end; ++i) {
|
||||
const int64_t col_idx = x_col_indices_ptr[batch_offset + i];
|
||||
const int64_t indices_offset = rank * (batch_offset + i);
|
||||
IndicesCompute<indicesT>(ctx, indices_offset, batch_idx, row_idx, col_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename indicesT>
|
||||
void CSRSparseMatrixToSparseTensorCpuKernel::IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset,
|
||||
const int64_t batch_idx, const int64_t row_idx,
|
||||
const int64_t col_idx) {
|
||||
const int rank = ctx.Input(0)->NumElements();
|
||||
auto indices_ptr = static_cast<indicesT *>(ctx.Output(0)->GetData());
|
||||
if (rank == DIM2) {
|
||||
indices_ptr[indices_offset] = row_idx;
|
||||
indices_ptr[++indices_offset] = col_idx;
|
||||
} else { // rank == 3
|
||||
indices_ptr[indices_offset] = batch_idx;
|
||||
indices_ptr[++indices_offset] = row_idx;
|
||||
indices_ptr[++indices_offset] = col_idx;
|
||||
}
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CSRSparseMatrixToSparseTensorCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~CSRSparseMatrixToSparseTensorCpuKernel() = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename indicesT, typename dataT>
|
||||
uint32_t ComputeKernel(CpuKernelContext &ctx);
|
||||
|
||||
template <typename indicesT>
|
||||
void SpecialCompute(int64_t batch_begin, int64_t batch_end, CpuKernelContext &ctx);
|
||||
|
||||
template <typename indicesT>
|
||||
void IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset, const int64_t batch_idx, const int64_t row_idx,
|
||||
const int64_t col_idx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,202 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cumprod.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kCumprodInputNum = 2;
|
||||
const uint32_t kCumprodOutputNum = 1;
|
||||
const int64_t paralled_data_size = 512 * 1024;
|
||||
const char *kCumprod = "Cumprod";
|
||||
#define CUMPROD_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = CumprodCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Cumprod kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CumprodCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCumprodInputNum, kCumprodOutputNum), "[%s] check input and output failed.",
|
||||
kCumprod);
|
||||
// parse params
|
||||
KERNEL_HANDLE_ERROR(CumprodCheck(ctx), "[%s] check params failed.", kCumprod);
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
CUMPROD_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_INT8, int8_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_INT16, int16_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
|
||||
CUMPROD_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Cumprod kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t CumprodCpuKernel::CumprodCheck(CpuKernelContext &ctx) {
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.")
|
||||
KERNEL_CHECK_NULLPTR(ctx.GetAttr("exclusive"), KERNEL_STATUS_PARAM_INVALID, "get exclusive failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.GetAttr("reverse"), KERNEL_STATUS_PARAM_INVALID, "get reverse failed.");
|
||||
KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_INT32 || ctx.Input(1)->GetDataType() == DT_INT64),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"Data type of axis is not support, axis data type is [%u], only support int32 or int64.",
|
||||
ctx.Input(1)->GetDataType());
|
||||
KERNEL_CHECK_FALSE(ctx.Input(1)->NumElements() == 1, KERNEL_STATUS_PARAM_INVALID, "axis is out of shape")
|
||||
auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
int32_t axis = *axis_data;
|
||||
KERNEL_CHECK_FALSE((axis < ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"axis is larger than input dims - 1");
|
||||
KERNEL_CHECK_FALSE((axis >= -ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"axis is lower than -input dims");
|
||||
std::vector<int64_t> shape_input = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
std::vector<int64_t> shape_output = ctx.Output(0)->GetTensorShape()->GetDimSizes();
|
||||
KERNEL_CHECK_FALSE((shape_input.size() != 0), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input must be at least rank 1, got [%zu].", shape_input.size())
|
||||
KERNEL_CHECK_FALSE((shape_input.size() == shape_output.size()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The output shape size should be same as the output shape size")
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
uint32_t CumprodCpuKernel::CumprodCompute(CpuKernelContext &ctx) {
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
int32_t axis = *axis_data;
|
||||
bool exclusive = ctx.GetAttr("exclusive")->GetBool();
|
||||
bool reverse = ctx.GetAttr("reverse")->GetBool();
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
auto shape = ctx.Input(0)->GetTensorShape();
|
||||
const int64_t rank = shape->GetDims();
|
||||
if (axis < 0) {
|
||||
axis += shape->GetDims();
|
||||
}
|
||||
size_t inner = 1;
|
||||
size_t outer = 1;
|
||||
size_t depth = 1;
|
||||
for (int32_t i = 0; i < rank; ++i) {
|
||||
if (i < axis) {
|
||||
inner *= shape->GetDimSize(i);
|
||||
} else if (i > axis) {
|
||||
outer *= shape->GetDimSize(i);
|
||||
} else {
|
||||
depth = shape->GetDimSize(i);
|
||||
}
|
||||
}
|
||||
int64_t data_num = ctx.Input(0)->NumElements();
|
||||
int64_t data_size = data_num * sizeof(T);
|
||||
if (data_size <= paralled_data_size) {
|
||||
for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
|
||||
size_t outer_index_adj;
|
||||
if (reverse)
|
||||
outer_index_adj = (outer - 1) - outer_index;
|
||||
else
|
||||
outer_index_adj = outer_index;
|
||||
for (size_t inner_index = 0; inner_index < inner; inner_index++) {
|
||||
auto multiplier = static_cast<T>(1);
|
||||
size_t inner_index_adj;
|
||||
if (reverse)
|
||||
inner_index_adj = (inner - 1) - inner_index;
|
||||
else
|
||||
inner_index_adj = inner_index;
|
||||
for (size_t depth_index = 0; depth_index < depth; depth_index++) {
|
||||
size_t depth_index_adj;
|
||||
if (reverse)
|
||||
depth_index_adj = (depth - 1) - depth_index;
|
||||
else
|
||||
depth_index_adj = depth_index;
|
||||
size_t index = outer_index_adj;
|
||||
index += inner_index_adj * depth * outer;
|
||||
index += depth_index_adj * outer;
|
||||
if (exclusive) {
|
||||
output_data[index] = multiplier;
|
||||
multiplier *= input_data[index];
|
||||
} else {
|
||||
multiplier *= input_data[index];
|
||||
output_data[index] = multiplier;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto shard_cumprod = [&](size_t start, size_t ene) {
|
||||
for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
|
||||
size_t outer_index_adj;
|
||||
if (reverse) {
|
||||
outer_index_adj = (outer - 1) - outer_index;
|
||||
} else {
|
||||
outer_index_adj = outer_index;
|
||||
}
|
||||
for (size_t inner_index = 0; inner_index < inner; inner_index++) {
|
||||
auto multiplier = static_cast<T>(1);
|
||||
size_t inner_index_adj;
|
||||
if (reverse) {
|
||||
inner_index_adj = (inner - 1) - inner_index;
|
||||
} else {
|
||||
inner_index_adj = inner_index;
|
||||
}
|
||||
for (size_t depth_index = 0; depth_index < depth; depth_index++) {
|
||||
size_t depth_index_adj;
|
||||
if (reverse) {
|
||||
depth_index_adj = (depth - 1) - depth_index;
|
||||
} else {
|
||||
depth_index_adj = depth_index;
|
||||
}
|
||||
size_t index = outer_index_adj;
|
||||
index += inner_index_adj * depth * outer;
|
||||
index += depth_index_adj * outer;
|
||||
if (exclusive) {
|
||||
output_data[index] = multiplier;
|
||||
multiplier *= input_data[index];
|
||||
} else {
|
||||
multiplier *= input_data[index];
|
||||
output_data[index] = multiplier;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
uint32_t min_core_num = 1;
|
||||
size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (max_core_num > outer) {
|
||||
max_core_num = outer;
|
||||
}
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0");
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, outer, outer / max_core_num, shard_cumprod),
|
||||
"Cumprod Compute failed.");
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kCumprod, CumprodCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CUMPROD_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CUMPROD_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CumprodCpuKernel : public CpuKernel {
|
||||
public:
|
||||
CumprodCpuKernel() = default;
|
||||
~CumprodCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t CumprodCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t CumprodCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -208,4 +208,4 @@ uint32_t CumulativeLogsumexpCpuKernel::CumulativeLogsumexpCompute(CpuKernelConte
|
|||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(KCumulativeLogsumexp, CumulativeLogsumexpCpuKernel);
|
||||
} // namespace aicpu
|
||||
} // namespace aicpu
|
||||
|
|
|
@ -35,4 +35,4 @@ class CumulativeLogsumexpCpuKernel : public CpuKernel {
|
|||
uint32_t CumulativeLogsumexpCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
||||
#endif
|
|
@ -0,0 +1,166 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "shuffle_channel.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 1;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kShuffleChannel = "ShuffleChannel";
|
||||
const int64_t minDimSize = 3;
|
||||
|
||||
#define SHUFFLE_CHANNEL_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ShuffleChannelCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Shuffle Channel kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ShuffleChannelCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "ShuffleChannel check input and output number failed.");
|
||||
KERNEL_HANDLE_ERROR(ShuffleChannelParamCheck(ctx), "ShuffleChannel check params failed.");
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Shuffle kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ShuffleChannelCpuKernel::ShuffleChannelParamCheck(CpuKernelContext &ctx) {
|
||||
// the non null of input_0, input_1, output has been verified in NormalCheck
|
||||
Tensor *input = ctx.Input(0);
|
||||
auto group = 1;
|
||||
if (ctx.GetAttr("group")) {
|
||||
group = ctx.GetAttr("group")->GetInt();
|
||||
}
|
||||
int64_t c = input->GetTensorShape()->GetDimSize(1);
|
||||
KERNEL_CHECK_FALSE(input->GetTensorShape()->GetDims() >= minDimSize, KERNEL_STATUS_PARAM_INVALID,
|
||||
"ShuffleChannel expect input with > 2 dims.")
|
||||
KERNEL_CHECK_FALSE(group > 0, KERNEL_STATUS_PARAM_INVALID, "Number of groups to divide channels in must be positive.")
|
||||
KERNEL_CHECK_FALSE((c % group) == 0, KERNEL_STATUS_PARAM_INVALID, "Number of channels must be divisible by groups")
|
||||
KERNEL_LOG_DEBUG(
|
||||
"ShuffleChannelCpuKernel[%s], input: size[%llu];"
|
||||
"output: size[%llu].",
|
||||
ctx.GetOpType().c_str(), input->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ShuffleChannelCpuKernel::ShuffleChannelCompute(CpuKernelContext &ctx) {
|
||||
Tensor *input = ctx.Input(0);
|
||||
Tensor *output = ctx.Output(0);
|
||||
auto group = 1;
|
||||
if (ctx.GetAttr("group")) {
|
||||
group = ctx.GetAttr("group")->GetInt();
|
||||
}
|
||||
|
||||
auto shape = input->GetTensorShape();
|
||||
int64_t b = shape->GetDimSize(0);
|
||||
int64_t c = shape->GetDimSize(1);
|
||||
int64_t dims = shape->GetDims();
|
||||
|
||||
if (group == 0) {
|
||||
KERNEL_LOG_ERROR("group divided can not be zero");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t oc = c / group;
|
||||
|
||||
auto out = reinterpret_cast<T *>(output->GetData());
|
||||
auto in = reinterpret_cast<T *>(input->GetData());
|
||||
int64_t loc_out = 0;
|
||||
int64_t loc_in = 0;
|
||||
int64_t area = 1;
|
||||
for (int64_t i = 2; i < dims; i++) {
|
||||
area = area * shape->GetDimSize(i);
|
||||
}
|
||||
/*
|
||||
view the shape to n g c/g h*w,and transpose dim 1 and dim 2
|
||||
*/
|
||||
std::vector<int64_t> temp_shape;
|
||||
temp_shape.push_back(b);
|
||||
temp_shape.push_back(group);
|
||||
temp_shape.push_back(oc);
|
||||
temp_shape.push_back(area);
|
||||
std::vector<int64_t> temp_loc = {0, 0, 0, 0};
|
||||
bool can_plus = false;
|
||||
int64_t dim = 0;
|
||||
while (true) {
|
||||
for (dim = 0; dim <= minDimSize; dim++) {
|
||||
if (dim == minDimSize) {
|
||||
loc_in = loc_in + temp_loc[dim] * 1;
|
||||
loc_out = loc_out + temp_loc[dim] * 1;
|
||||
} else if (dim == (minDimSize - 1)) {
|
||||
loc_in = loc_in + temp_loc[dim] * area;
|
||||
loc_out = loc_out + temp_loc[1] * area;
|
||||
} else if (dim == 1) {
|
||||
loc_in = loc_in + temp_loc[dim] * oc * area;
|
||||
loc_out = loc_out + temp_loc[minDimSize - 1] * group * area;
|
||||
} else if (dim == 0) {
|
||||
loc_in = loc_in + temp_loc[dim] * group * oc * area;
|
||||
loc_out = loc_out + temp_loc[dim] * group * oc * area;
|
||||
}
|
||||
}
|
||||
*(out + loc_out) = *(in + loc_in);
|
||||
loc_in = 0;
|
||||
loc_out = 0;
|
||||
can_plus = false;
|
||||
for (dim = 0; dim <= minDimSize; dim++) {
|
||||
if (temp_loc[dim] < (temp_shape[dim] - 1)) {
|
||||
can_plus = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!can_plus) {
|
||||
break;
|
||||
}
|
||||
for (dim = minDimSize; dim >= 0; dim--) {
|
||||
if (temp_loc[dim] == (temp_shape[dim] - 1)) {
|
||||
temp_loc[dim] = 0;
|
||||
} else {
|
||||
temp_loc[dim] = temp_loc[dim] + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kShuffleChannel, ShuffleChannelCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ShuffleChannelCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ShuffleChannelCpuKernel() = default;
|
||||
~ShuffleChannelCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ShuffleChannelParamCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t ShuffleChannelCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -52,9 +52,24 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
kFSEDecodeOpName};
|
||||
static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2DV1OpName,
|
||||
mindspore::kAdaptiveAvgPool2DGradV1OpName,
|
||||
mindspore::kBucketizeOpName,
|
||||
mindspore::kCacheSwapTableOpName,
|
||||
mindspore::kCauchyOpName,
|
||||
mindspore::kChannelShuffleOpName,
|
||||
mindspore::kCholeskyGradOpName,
|
||||
mindspore::kCholeskyInverseOpName,
|
||||
mindspore::kCholeskySolveOpName,
|
||||
mindspore::kCol2imOpName,
|
||||
mindspore::kCombinedNonMaxSuppressionOpName,
|
||||
mindspore::kComplexOpName,
|
||||
mindspore::kComplexAbsOpName,
|
||||
mindspore::kConcatOpName,
|
||||
mindspore::kCosOpName,
|
||||
mindspore::kCountNonZeroOpName,
|
||||
mindspore::kCumulativeLogsumexpOpName,
|
||||
mindspore::kCumprodOpName,
|
||||
mindspore::kCSRSparseMatrixToDenseOpName,
|
||||
mindspore::kCSRSparseMatrixToSparseTensorOpName,
|
||||
mindspore::kDataFormatVecPermuteOpName,
|
||||
mindspore::kFillOpName,
|
||||
mindspore::kLogMatrixDeterminantOpName,
|
||||
|
|
|
@ -184,6 +184,20 @@ from .nan_to_num import _nan_to_num_aicpu
|
|||
from .qr import _qr_aicpu
|
||||
from .col2im import _col2im_aicpu
|
||||
from .matrix_solve_ls import _matrix_solve_ls_aicpu
|
||||
from .cauchy import _cauchy_aicpu
|
||||
from .bucketize import _bucketize_aicpu
|
||||
from .channel_shuffle import _channel_shuffle_aicpu
|
||||
from .choleskygrad import _choleskygrad_aicpu
|
||||
from .cholesky_inverse import _cholesky_inverse_aicpu
|
||||
from .cholesky_solve import _cholesky_solve_aicpu
|
||||
from .combined_non_max_suppression import _combined_non_max_suppression_aicpu
|
||||
from .complex import _complex_aicpu
|
||||
from .complex_abs import _complex_abs_aicpu
|
||||
from .concat import _concat_aicpu
|
||||
from .cos import _cos_aicpu
|
||||
from .count_nonzero import _count_nonzero_aicpu
|
||||
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
|
||||
from .cumprod import _cumprod_aicpu
|
||||
from .exp import _exp_aicpu
|
||||
from .matrix_triangular_solve import _matrix_triangular_solve_aicpu
|
||||
from .maximum_grad_grad import _maximum_grad_grad_aicpu
|
||||
|
|
Loading…
Reference in New Issue