From f00862beb2724cd345dec328b8f41c8955cfdeac Mon Sep 17 00:00:00 2001
From: gaoshuanglong
Date: Thu, 5 Jan 2023 15:37:54 +0800
Subject: [PATCH] add aicpu_ops

---
 .jenkins/check/config/filter_cppcheck.txt | 1 +
 .jenkins/check/config/filter_cpplint.txt | 3 +
 .jenkins/check/config/whitelizard.txt | 3 +
 mindspore/ccsrc/include/common/utils/utils.h | 12 +
 .../cpu_kernel/ms_kernel/bucketize.cc | 127 +++++
 .../cpu_kernel/ms_kernel/bucketize.h | 35 ++
 .../aicpu_ops/cpu_kernel/ms_kernel/cauchy.cc | 111 +++++
 .../aicpu_ops/cpu_kernel/ms_kernel/cauchy.h | 41 ++
 .../cpu_kernel/ms_kernel/cholesky_grad.cc | 199 ++++++++
 .../cpu_kernel/ms_kernel/cholesky_grad.h | 50 ++
 .../cpu_kernel/ms_kernel/cholesky_inverse.cc | 88 ++++
 .../cpu_kernel/ms_kernel/cholesky_inverse.h | 19 +
 .../cpu_kernel/ms_kernel/cholesky_solve.cc | 88 ++++
 .../cpu_kernel/ms_kernel/cholesky_solve.h | 38 ++
 .../ms_kernel/combined_non_max_suppression.cc | 447 ++++++++++++++++++
 .../ms_kernel/combined_non_max_suppression.h | 83 ++++
 .../aicpu_ops/cpu_kernel/ms_kernel/complex.cc | 97 ++++
 .../aicpu_ops/cpu_kernel/ms_kernel/complex.h | 40 ++
 .../cpu_kernel/ms_kernel/complex_abs.cc | 113 +++++
 .../cpu_kernel/ms_kernel/complex_abs.h | 28 ++
 .../aicpu_ops/cpu_kernel/ms_kernel/concat.cc | 303 ++++++++++++
 .../aicpu_ops/cpu_kernel/ms_kernel/concat.h | 76 +++
 .../aicpu_ops/cpu_kernel/ms_kernel/cos.cc | 123 +++++
 .../aicpu_ops/cpu_kernel/ms_kernel/cos.h | 28 ++
 .../cpu_kernel/ms_kernel/count_nonzero.cc | 272 +++++++++++
 .../cpu_kernel/ms_kernel/count_nonzero.h | 31 ++
 .../ms_kernel/csr_sparse_matrix_to_dense.cc | 124 +++++
 .../ms_kernel/csr_sparse_matrix_to_dense.h | 35 ++
 .../csr_sparse_matrix_to_sparse_tensor.cc | 229 +++++++++
 .../csr_sparse_matrix_to_sparse_tensor.h | 41 ++
 .../aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc | 202 ++++++++
 .../aicpu_ops/cpu_kernel/ms_kernel/cumprod.h | 37 ++
 .../ms_kernel/cumulativelogsumexp.cc | 2 +-
 .../ms_kernel/cumulativelogsumexp.h | 2 +-
 .../cpu_kernel/ms_kernel/shuffle_channel.cc | 166 +++++++
 .../cpu_kernel/ms_kernel/shuffle_channel.h | 39 ++
 .../optimizer/mindir/aicpu_lib_select.cc | 15 +
 .../mindspore/ops/_op_impl/aicpu/__init__.py | 14 +
 38 files changed, 3360 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.h
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/shuffle_channel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/shuffle_channel.h

diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index d30d6e8b6bf..6934837b381 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -86,6 +86,7 @@
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constVariable"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "redundantAssignment"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constArgument"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "unknownMacro"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"
diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt
index 000b2983610..4b003113bdd 100644
--- a/.jenkins/check/config/filter_cpplint.txt
+++ b/.jenkins/check/config/filter_cpplint.txt
@@ -132,3 +132,6 @@
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/newline"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/operators"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comma"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/indentation_namespace"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/line_length"
+
diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 7811b6397b5..4ed8424e90b 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -282,6 +282,9 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d.cc:aicpu::MaxUnpool2DCpuKernel::MaxUnpool2DCompute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_solve_ls.cc:aicpu::MatrixSolveLsCpuKernel::Compute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/col2im.cc:aicpu::Col2imCpuKernel::Col2imParamCheck
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc:aicpu::CSRSparseMatrixToSparseTensorCpuKernel::Compute
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc:aicpu::CumprodCpuKernel::CumprodCompute
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc:aicpu::CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd_update.cc:aicpu::ScatterNdUpdateCpuKernel::Compute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/ragged_tensor_to_sparse.cc:aicpu::RaggedTensorToSparseCpuKernel::Compute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_3d_grad.cc:aicpu::MaxUnpool3DGradCpuKernel::MaxUnpool3DGradCompute
diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h
index aadfb726a76..d75e55de388 100644
--- a/mindspore/ccsrc/include/common/utils/utils.h
+++ b/mindspore/ccsrc/include/common/utils/utils.h
@@ -149,15 +149,21 @@ constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
 constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
 constexpr auto kBpropCutOpName = "bprop_cut";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
+constexpr auto kBucketizeOpName = "Bucketize";
 constexpr auto kIFMROpName = "IFMR";
 constexpr auto kBroadcastToDOpName = "BroadcastToD";
 constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
 constexpr auto kCallOpName = "call";
+constexpr auto kCauchyOpName = "Cauchy";
 constexpr auto kCastOpName = "Cast";
 constexpr auto kCentralizationOpName = "Centralization";
 constexpr auto kCeLUOpName = "CeLU";
 constexpr auto kCeluV2OpName = "CeluV2";
+constexpr auto kChannelShuffleOpName = "ChannelShuffle";
 constexpr auto kCheckNumericsOpName = "CheckNumerics";
+constexpr auto kCholeskyGradOpName = "CholeskyGrad";
+constexpr auto kCholeskyInverseOpName = "CholeskyInverse";
+constexpr auto kCholeskySolveOpName = "CholeskySolve";
 constexpr auto kClearZeroOpName = "ClearZero";
 constexpr auto kClipBoxesOpName = "kClipBoxes";
 constexpr auto kClipBoxesDOpName = "kClipBoxesD";
@@ -165,8 +171,11 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
 constexpr auto kClipByValueOpName = "ClipByValue";
 constexpr auto kCoalesceOpName = "Coalesce";
 constexpr auto kCol2imOpName = "Col2im";
+constexpr auto kCombinedNonMaxSuppressionOpName = "CombinedNonMaxSuppression";
 constexpr auto kCombineMomentumOpName = "CombineMomentum";
 constexpr auto kCombineMomentumWeightOpName = "CombineMomentumWeight";
+constexpr auto kComplexOpName = "Complex";
+constexpr auto kComplexAbsOpName = "ComplexAbs";
 constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
 constexpr auto kConcatOpName = "Concat";
 constexpr auto kConcatDOpName = "ConcatD";
@@ -193,6 +202,8 @@ constexpr auto kCropAndResizeOpName = "CropAndResize";
 constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
 constexpr auto kConvBN1OpName = "ConvBN1";
 constexpr auto kCOO2CSROpName = "COO2CSR";
+constexpr auto kCosOpName = "Cos";
+constexpr auto kCountNonZeroOpName = "CountNonZero";
 constexpr auto kCSR2COOOpName = "CSR2COO";
 constexpr auto kCSRDivOpName = "CSRDiv";
 constexpr auto kCSRGatherOpName = "CSRGather";
@@ -200,6 +211,7 @@ constexpr auto kCSRMMOpName = "CSRMM";
 constexpr auto kCSRMulOpName = "CSRMul";
 constexpr auto kCSRMVOpName = "CSRMV";
 constexpr auto kCSRReduceSumOpName = "CSRReduceSum";
+constexpr auto kCSRSparseMatrixToDenseOpName = "CSRSparseMatrixToDense";
 constexpr auto kCSRSparseMatrixToSparseTensorOpName = "CSRSparseMatrixToSparseTensor";
 constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
 constexpr auto kCumprodOpName = "Cumprod";
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.cc
new file mode 100644
index 00000000000..a1cc5420f7b
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.cc
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "bucketize.h"
+
+#include "cpu_kernel_utils.h"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+
+namespace {
+const uint32_t kOutputNum = 1;
+const uint32_t kInputNum = 1;
+const char *kBucketize = "Bucketize";
+const int64_t kParallelDataNumSameShape = 64 * 1024;
+const int64_t kParallelDataNumSameShapeMid = 35 * 1024;
+
+#define BUCKETIZE_COMPUTE_CASE(DTYPE, TYPE, CTX)              \
+  case (DTYPE): {                                             \
+    uint32_t result = BucketizeCompute<TYPE>(CTX);            \
+    if (result != KERNEL_STATUS_OK) {                         \
+      KERNEL_LOG_ERROR("Bucketize kernel compute failed.");   \
+      return result;                                          \
+    }                                                         \
+    break;                                                    \
+  }
+
+int64_t get_tensor_length(aicpu::Tensor *t) {
+  std::vector<int64_t> dim_sizes = t->GetTensorShape()->GetDimSizes();
+  int64_t length = 1;
+  for (auto x : dim_sizes) {
+    length *= x;
+  }
+  return length;
+}
+}  // namespace
+
+namespace aicpu {
+uint32_t BucketizeCpuKernel::Compute(CpuKernelContext &ctx) {
+  // normal check
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Bucketize check input and output number failed.");
+  auto data_type = ctx.Input(0)->GetDataType();
+
+  KERNEL_CHECK_NULLPTR(ctx.GetAttr("boundaries"), KERNEL_STATUS_PARAM_INVALID, "Get boundaries failed")
+
+  // check input datatype
+  Tensor *input = ctx.Input(kFirstInputIndex);
+  DataType dt_input = input->GetDataType();
+  KERNEL_CHECK_FALSE((dt_input == DT_FLOAT || dt_input == DT_INT32 || dt_input == DT_INT64 || dt_input == DT_DOUBLE),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "Input data type must DT_FLOAT or DT_INT32 or DT_INT64 or DT_DOUBLE,"
+                     "but got data type[%s].",
+                     DTypeStr(dt_input).c_str());
+
+  // check output datatype
+  Tensor *output = ctx.Output(kFirstOutputIndex);
+  DataType dt_output = output->GetDataType();
+  KERNEL_CHECK_FALSE((dt_output == DT_INT32), KERNEL_STATUS_PARAM_INVALID,
+                     "Output data type must DT_INT32, but got data type[%s].", DTypeStr(dt_output).c_str());
+
+  auto input_sizes = input->GetTensorShape()->GetDimSizes();
+  auto output_sizes = output->GetTensorShape()->GetDimSizes();
+  KERNEL_CHECK_FALSE((input_sizes == output_sizes), KERNEL_STATUS_PARAM_INVALID,
+                     "The tensor shape of input [%s] need be same with "
+                     "output [%s].",
+                     VectorToString(input_sizes).c_str(), VectorToString(output_sizes).c_str());
+
+  switch (data_type) {
+    BUCKETIZE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
+    BUCKETIZE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
+    BUCKETIZE_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    BUCKETIZE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
+    default:
+      KERNEL_LOG_ERROR("Bucketize kernel data type [%s] not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T>
+uint32_t BucketizeCpuKernel::BucketizeCompute(CpuKernelContext &ctx) {
+  const int64_t data_num = get_tensor_length(ctx.Input(0));
+  auto boundaries = ctx.GetAttr("boundaries");
+  std::vector<float> boundaries_data = boundaries->GetListFloat();
+  std::sort(boundaries_data.begin(), boundaries_data.end());
+  auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
+  auto output_data = reinterpret_cast<int32_t *>(ctx.Output(0)->GetData());
+
+  if (data_num >= kParallelDataNumSameShape) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
+
+    if (data_num <= kParallelDataNumSameShapeMid) {
+      max_core_num = std::min(max_core_num, 4U);  // up to 4 cpu cores
+    }
+
+    if (max_core_num > data_num) {
+      max_core_num = data_num;
+    }
+    auto sharder_bucketize = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
+        output_data[i] = first_bigger_it - boundaries_data.begin();
+      }
+    };
+    CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_bucketize);
+  } else {
+    for (int64_t i = 0; i < data_num; i++) {
+      auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]);
+      output_data[i] = first_bigger_it - boundaries_data.begin();
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kBucketize, BucketizeCpuKernel);
+}  // namespace aicpu
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.h
new file mode 100644
index 00000000000..06c0188b844
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bucketize.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
+#define AICPU_KERNELS_NORMALIZED_BUCKETIZE_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+class BucketizeCpuKernel : public CpuKernel {
+ public:
+  BucketizeCpuKernel() = default;
+  ~BucketizeCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+  template <typename T>
+  static uint32_t BucketizeCompute(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.cc
new file mode 100644
index 00000000000..9a5a4497c51
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.cc
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cauchy.h"
+#include "cpu_kernel_utils.h"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include <algorithm>
+#include <random>
+
+namespace {
+const char *kCauchy = "Cauchy";
+const int64_t kParallelDataNums = 64 * 1024;
+const uint32_t knum = 2;
+}  // namespace
+
+#define CAUCHY_COMPUTE_CASE(DTYPE, TYPE, CTX)               \
+  case (DTYPE): {                                           \
+    uint32_t result = CauchyCompute<TYPE>(CTX);             \
+    if (result != KERNEL_STATUS_OK) {                       \
+      KERNEL_LOG_ERROR("Cauchy kernel compute failed.");    \
+      return result;                                        \
+    }                                                       \
+    break;                                                  \
+  }
+namespace aicpu {
+uint32_t CauchyCpuKernel::Compute(CpuKernelContext &ctx) {
+  Tensor *output_tensor = ctx.Output(0);
+  KERNEL_CHECK_NULLPTR(output_tensor, KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_tensor is nullptr.");
+  auto output_dtype = output_tensor->GetDataType();
+
+  switch (output_dtype) {
+    CAUCHY_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
+    CAUCHY_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    default:
+      KERNEL_LOG_ERROR("Cauchy kernel data type [%s] not support.", DTypeStr(output_dtype).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T>
+uint32_t CauchyCpuKernel::CauchyCompute(CpuKernelContext &ctx) {
+  AttrValue *median = ctx.GetAttr("median");
+  if (median != nullptr) {
+    median_ = median->GetFloat();
+  }
+
+  AttrValue *sigma = ctx.GetAttr("sigma");
+  if (sigma != nullptr) {
+    sigma_ = sigma->GetFloat();
+  }
+
+  Tensor *y_tensor = ctx.Output(0);
+
+  AttrValue *output_size_attr = ctx.GetAttr("size");
+  KERNEL_CHECK_NULLPTR(output_size_attr, KERNEL_STATUS_PARAM_INVALID, "CauchyCompute get size failed.");
+  std::vector<int64_t> output_size = ctx.GetAttr("size")->GetListInt();
+  if (output_size.empty()) {
+    KERNEL_LOG_ERROR("CauchyCompute get size is empty.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  auto y_shape = y_tensor->GetTensorShape();
+  y_shape->SetDimSizes(output_size);
+
+  int64_t y_num = y_tensor->NumElements();
+  KERNEL_CHECK_NULLPTR(y_tensor->GetData(), KERNEL_STATUS_INNER_ERROR, "CauchyCompute check output_data is nullptr.");
+  T *y_data = static_cast<T *>(y_tensor->GetData());
+  std::default_random_engine generator(std::random_device{}());
+
+  std::cauchy_distribution<float> cauchy_d(median_, sigma_);
+  uint32_t max_core_num = 1;
+  if (y_num >= kParallelDataNums) {
+    max_core_num = std::max(max_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - knum);
+    if (max_core_num > y_num) {
+      max_core_num = y_num;
+    }
+  }
+
+  auto Cauchy_d = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; ++i) {
+      float data = cauchy_d(generator);
+      y_data[i] = static_cast<T>(data);
+    }
+  };
+
+  uint32_t ret = CpuKernelUtils::ParallelFor(ctx, y_num, y_num / max_core_num, Cauchy_d);
+  if (ret != KERNEL_STATUS_OK) {
+    KERNEL_LOG_ERROR("CpuKernelUtils::ParallelFor failed.");
+    return KERNEL_STATUS_INNER_ERROR;
+  }
+  KERNEL_LOG_INFO("CauchyCpuKernel::ComputeCauchy end.");
+  return KERNEL_STATUS_OK;
+}
+
+REGISTER_CPU_KERNEL(kCauchy, CauchyCpuKernel);
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.h
new file mode 100644
index 00000000000..273ba14508a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cauchy.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
+#define AICPU_KERNELS_NORMALIZED_CAUCHY_WINDOW_H_
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_SIMPLE_THREAD_POOL
+
+#include "cpu_ops_kernel.h"
+#include "cpu_types.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+
+namespace aicpu {
+class CauchyCpuKernel : public CpuKernel {
+ public:
+  CauchyCpuKernel() = default;
+  ~CauchyCpuKernel() = default;
+
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename T>
+  uint32_t CauchyCompute(CpuKernelContext &ctx);
+
+  float median_ = 0.0;
+  float sigma_ = 1.0;
+};
+}  // namespace aicpu
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.cc
new file mode 100644
index 00000000000..4ace8ae8b03
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cholesky_grad.h"
+#include <algorithm>
+#include <iostream>
+#include <vector>
+#include "cpu_kernel_utils.h"
+
+namespace {
+const uint32_t kInputNum = 2;
+const uint32_t kOutputNum = 1;
+const char *CholeskyGrad = "CholeskyGrad";
+}  // namespace
+
+namespace aicpu {
+uint32_t CholeskyGradCpuKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskyGrad check input and output number failed.");
+  Tensor *input0 = ctx.Input(0);
+  Tensor *input1 = ctx.Input(1);
+  Tensor *output0 = ctx.Output(0);
+  if (input0->GetDataSize() == 0 || input1->GetDataSize() == 0) {
+    KERNEL_LOG_ERROR("[%s] Input tensor is empty tensor.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  auto shape0 = input0->GetTensorShape();
+  auto shape1 = input1->GetTensorShape();
+  auto shape2 = output0->GetTensorShape();
+  if (shape0->GetDims() != shape1->GetDims() || shape1->GetDims() != shape2->GetDims()) {
+    KERNEL_LOG_ERROR("[%s] Inputs and Output tensors should have same dims.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  auto dims = shape0->GetDims();
+  if (shape0->GetDimSize(dims - 1) != shape0->GetDimSize(dims - 2)) {
+    KERNEL_LOG_ERROR("[%s] Tensor input0 is not square.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  if ((shape0->GetDimSize(dims - 1) != shape1->GetDimSize(dims - 1)) ||
+      (shape0->GetDimSize(dims - 2) != shape1->GetDimSize(dims - 2))) {
+    KERNEL_LOG_ERROR("[%s] Tensor input0&input1's shape mismatch.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  if ((shape0->GetDimSize(dims - 1) != shape2->GetDimSize(dims - 1)) ||
+      (shape0->GetDimSize(dims - 2) != shape2->GetDimSize(dims - 2))) {
+    KERNEL_LOG_ERROR("[%s] Tensor input0&output0's shape mismatch.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  auto data_type_0 = input0->GetDataType();
+  auto data_type_1 = input1->GetDataType();
+  auto data_type_2 = output0->GetDataType();
+  if (data_type_0 != data_type_1 || data_type_0 != data_type_2) {
+    KERNEL_LOG_ERROR("[%s] Tensor data type mismatch.", ctx.GetOpType().c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  if (data_type_0 != DT_FLOAT && data_type_0 != DT_DOUBLE) {
+    KERNEL_LOG_ERROR("CholeskyGrad kernel data type [%u] not support.", data_type_0);
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  if (data_type_0 == DT_FLOAT) {
+    return ComputeKernel<float>(ctx, true);
+  } else {
+    return ComputeKernel<double>(ctx, true);
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T>
+uint32_t CholeskyGradCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &reverse) {
+  auto dims = ctx.Input(0)->GetTensorShape()->GetDims();
+  auto lptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
+  auto gradptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto outputptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+  int n = ctx.Input(0)->GetTensorShape()->GetDimSize(dims - 1);
+  int64_t data_num = ctx.Input(0)->NumElements();
+  const int64_t mat_size = n * n;
+  const int64_t batch = data_num / mat_size;
+  const int64_t kParallelDataNum = 16 * mat_size;
+  const int64_t kParallelDataNumMid = 72 * mat_size;
+  if (data_num >= kParallelDataNum) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
+
+    if (data_num <= kParallelDataNumMid) {
+      max_core_num = std::min(max_core_num, 4U);  // up to 4 cpu cores
+    }
+
+    auto sharder_cholesky_grad = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
+      }
+    };
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch, batch / max_core_num, sharder_cholesky_grad),
+                        "CholeskyGrad Compute failed.");
+
+  } else {
+    for (int64_t i = 0; i < batch; i++) {
+      ComputeMatrix(lptr + i * mat_size, gradptr + i * mat_size, outputptr + i * mat_size, n);
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T>
+void CholeskyGradCpuKernel::ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n) {
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigengrad(n, n);
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigenl(n, n);
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> output_matrix(n, n);
+  for (int i = 0; i < n * n; i++) {
+    *(eigengrad.data() + i) = *(gradptr + i);
+    *(eigenl.data() + i) = *(lptr + i);
+  }
+
+  // Algorithm only depends on lower triangular half on input_matrix_l.
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> input_matrix_l =
+    eigenl.template triangularView<Eigen::Lower>();
+  // Algorithm only depends on lower triangular half on input_matrix_grad.
+  output_matrix = eigengrad.template triangularView<Eigen::Lower>();
+
+  const int64_t kMatrixSize = input_matrix_l.rows();
+  const int64_t kMaxBlockSize = 32;
+
+  for (int64_t block_end = kMatrixSize; block_end > 0; block_end -= kMaxBlockSize) {
+    const int64_t block_begin = std::max(int64_t{0}, block_end - kMaxBlockSize);
+    const int64_t block_size = block_end - block_begin;
+    const int64_t trailing_size = kMatrixSize - block_end;
+
+    auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
+    auto B_bar = output_matrix.block(block_end, 0, trailing_size, block_begin);
+
+    auto C = input_matrix_l.block(block_end, block_begin, trailing_size, block_size);
+    auto C_bar = output_matrix.block(block_end, block_begin, trailing_size, block_size);
+
+    auto D = input_matrix_l.block(block_begin, block_begin, block_size, block_size);
+    auto D_bar = output_matrix.block(block_begin, block_begin, block_size, block_size);
+
+    auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
+    auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
+
+    C_bar = D.adjoint().template triangularView<Eigen::Upper>().solve(C_bar.adjoint()).adjoint();
+    D_bar -= (C_bar.adjoint() * C).template triangularView<Eigen::Lower>();
+    B_bar -= C_bar * R;
+    R_bar -= C_bar.adjoint() * B;
+    CholeskyGradUnblocked(D, D_bar);
+    R_bar -= (D_bar + D_bar.adjoint()) * R;
+  }
+  output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < n; j++) {
+      *(outputptr + i * n + j) = output_matrix(i, j);
+    }
+  }
+}
+
+template <typename T>
+void CholeskyGradCpuKernel::CholeskyGradUnblocked(
+  const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
+  Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block) {
+  const int64_t kMatrixSize = l_block.rows();
+  for (int64_t k = kMatrixSize - 1; k >= 0; k--) {
+    const int64_t number_rows_B = kMatrixSize - (k + 1);
+    const int64_t number_rows_r_stack_B = number_rows_B + 1;
+
+    auto r = l_block.block(k, 0, 1, k);
+    auto r_bar = grad_block.block(k, 0, 1, k);
+    auto d = l_block(k, k);  // This needs to be a scalar rather than a view.
+    auto d_bar = grad_block.block(k, k, 1, 1);
+    // B is not included explicitly because it is not used on its own.
+    auto B_bar = grad_block.block(k + 1, 0, number_rows_B, k);
+    auto c = l_block.block(k + 1, k, number_rows_B, 1);
+    auto c_bar = grad_block.block(k + 1, k, number_rows_B, 1);
+    // Result of vertical stacking d_bar and c_bar.
+    auto d_stack_c_bar = grad_block.block(k, k, number_rows_r_stack_B, 1);
+    // Result of vertical stacking of r and B.
+    auto r_stack_B = l_block.block(k, 0, number_rows_r_stack_B, k);
+    d_bar -= (c.adjoint() * c_bar) / d;
+    d_stack_c_bar /= d;
+    r_bar -= d_stack_c_bar.adjoint() * r_stack_B;
+    B_bar -= c_bar * r;
+    d_bar /= 2.;
+  }
+}
+
+REGISTER_CPU_KERNEL(CholeskyGrad, CholeskyGradCpuKernel);
+}  // namespace aicpu
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.h
new file mode 100644
index 00000000000..9baca38ede5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_grad.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
+#define AICPU_KERNELS_NORMALIZED_CHOLESKY_GRAD_H_
+
+#include "cpu_ops_kernel.h"
+#include "cpu_types.h"
+#include "utils/bcast.h"
+#include <vector>
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include <Eigen/Dense>
+
+namespace aicpu {
+class CholeskyGradCpuKernel : public CpuKernel {
+ public:
+  CholeskyGradCpuKernel() = default;
+  ~CholeskyGradCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename T>
+  uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &reverse);
+
+  template <typename T>
+  void ComputeMatrix(T *lptr, T *gradptr, T *outputptr, int64_t n);
+
+  template <typename T>
+  void CholeskyGradUnblocked(
+    const Eigen::Ref<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> &l_block,
+    Eigen::Ref<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> grad_block);
+};
+}  // namespace aicpu
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.cc
new file mode 100644
index 00000000000..948c765290a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.cc
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cholesky_inverse.h"
+#include "cpu_kernel_utils.h"
+#include "utils/kernel_util.h"
+#include <Eigen/Dense>
+#include <vector>
+namespace {
+const uint32_t kOutputNum = 1;
+const uint32_t kInputNum = 1;
+const uint32_t dimension = 2;
+const char *kCholeskyInverse = "CholeskyInverse";
+}  // namespace
+
+namespace aicpu {
+uint32_t CholeskyInverseCpuKernel::Compute(CpuKernelContext &ctx) {
+  // check params
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check CholeskyInverse params failed.");
+  Tensor *input = ctx.Input(0);
+  KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");
+  Tensor *output = ctx.Output(0);
+  auto inputShape = input->GetTensorShape();
+  KERNEL_CHECK_NULLPTR(inputShape, KERNEL_STATUS_PARAM_INVALID, "Get inputShape failed.");
+  AttrValue *upper = ctx.GetAttr("upper");
+  KERNEL_CHECK_NULLPTR(upper, KERNEL_STATUS_PARAM_INVALID, "Get upper failed.");
+  KERNEL_LOG_DEBUG(
+    "CholeskyInverseCpuKernel[%s], input: size[%llu];"
+    "output: size[%llu].",
+    ctx.GetOpType().c_str(), input->GetDataSize(), output->GetDataSize());
+  auto input_dims = inputShape->GetDims();
+  if (input_dims != dimension) {
+    KERNEL_LOG_ERROR("CholeskyInverse input dim must be 2!");
+    return KERNEL_STATUS_PARAM_INVALID;
+  } else if (inputShape->GetDimSize(input_dims - 2) != inputShape->GetDimSize(input_dims - 1)) {
+    KERNEL_LOG_ERROR("CholeskyInverse input matrix must be square matrix!");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  DataType data_type = ctx.Input(0)->GetDataType();
+  switch (data_type) {
+    case DT_FLOAT:
+      return CholeskyInverseCompute<float>(ctx);
+    case DT_DOUBLE:
+      return CholeskyInverseCompute<double>(ctx);
+    default:
+      KERNEL_LOG_ERROR("CholeskyInverse kernel data type [%s] not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+  return KERNEL_STATUS_OK;
+}
+template <typename T>
+uint32_t CholeskyInverseCpuKernel::CholeskyInverseCompute(CpuKernelContext &ctx) {
+  auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
+  auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+  auto inputShape = ctx.Input(0)->GetTensorShape();
+  int64_t n = inputShape->GetDimSize(0);
+  typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
+  Eigen::Map<MatrixXd> A(input_x, n, n);
+  MatrixXd result;
+  AttrValue *upper = ctx.GetAttr("upper");
+  bool val = upper->GetBool();
+  if (val) {
+    result = (A.transpose() * A).inverse();
+  } else {
+    result = (A * A.transpose()).inverse();
+  }
+  for (int64_t i = 0; i < n; i++) {
+    for (int64_t j = 0; j < n; j++) {
+      *(output_y + i * n + j) = result(i, j);
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kCholeskyInverse, CholeskyInverseCpuKernel);
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.h
new file mode 100644
index 00000000000..86dec4606ad
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_inverse.h
@@ -0,0 +1,19 @@
+#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
+#define AICPU_KERNELS_NORMALIZED_CHOLESKYINVERSE_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+
+class CholeskyInverseCpuKernel : public CpuKernel {
+ public:
+  CholeskyInverseCpuKernel() = default;
+  ~CholeskyInverseCpuKernel() = default;
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename T>
+  static uint32_t CholeskyInverseCompute(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.cc
new file mode 100644
index 00000000000..1bd2946839c
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.cc
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "cholesky_solve.h"
+#include <Eigen/Dense>
+#include <algorithm>
+#include <iostream>
+#include <vector>
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include "cpu_kernel_utils.h"
+
+namespace {
+const uint32_t kInputNum = 2;
+const uint32_t kOutputNum = 1;
+const char *CholeskySolve = "CholeskySolve";
+}  // namespace
+
+namespace aicpu {
+uint32_t CholeskySolveCpuKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CholeskySolve check input and output failed.");
+
+  Tensor *input_x1 = ctx.Input(0);
+  AttrValue *upper = ctx.GetAttr("upper");
+  bool upperinfo = (upper == nullptr) ? false : upper->GetBool();
+  auto data_type_x1 = input_x1->GetDataType();
+
+  switch (data_type_x1) {
+    case DT_FLOAT:
+      return ComputeKernel<float>(ctx, upperinfo);
+    case DT_DOUBLE:
+      return ComputeKernel<double>(ctx, upperinfo);
+    default:
+      KERNEL_LOG_ERROR("CholeskySolve kernel data type [%s] not support.", DTypeStr(data_type_x1).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+}
+
+REGISTER_CPU_KERNEL(CholeskySolve, CholeskySolveCpuKernel);
+
+template <typename T>
+uint32_t CholeskySolveCpuKernel::ComputeKernel(CpuKernelContext &ctx, const bool &upper) {
+  auto rhsptr = reinterpret_cast<T *>(ctx.Input(0)->GetData());
+  auto lhsptr = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto outptr = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+  size_t batch_size = 1;
+  std::vector<int64_t> dims = ctx.Input(0)->GetTensorShape()->GetDimSizes();
+  size_t dimsnum = ctx.Input(0)->GetTensorShape()->GetDims();
+  size_t dim = dims[dimsnum - 2];
+  size_t rhs_dim = dims[dimsnum - 1];
+  if (dimsnum == 3) {
+    batch_size = dims[dimsnum - 3];
+  }
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RHS(dim, rhs_dim);
+  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> LHS(dim, dim);
+  for (size_t k = 0; k < batch_size; k++) {
+    for (size_t i = 0; i < dim * rhs_dim; i++) {
+      RHS.data()[i] = rhsptr[k * dim * rhs_dim + i];
+    }
+    for (size_t i = 0; i < dim * dim; i++) {
+      LHS.data()[i] = lhsptr[k * dim * dim + i];
+    }
+    if (!upper) {
+      LHS.template triangularView<Eigen::Lower>().solveInPlace(RHS);
+      LHS.adjoint().template triangularView<Eigen::Upper>().solveInPlace(RHS);
+    } else {
+      LHS.adjoint().template triangularView<Eigen::Lower>().solveInPlace(RHS);
+      LHS.template triangularView<Eigen::Upper>().solveInPlace(RHS);
+    }
+    for (size_t i = 0; i < dim * rhs_dim; i++) {
+      outptr[k * dim * rhs_dim + i] = RHS.data()[i];
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.h
new file mode 100644
index 00000000000..e74fccc9d8e
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cholesky_solve.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
+#define AICPU_KERNELS_NORMALIZED_CHOLESKY_SOLVE_H_
+
+#include "cpu_ops_kernel.h"
+#include "cpu_types.h"
+#include "utils/bcast.h"
+
+namespace aicpu {
+class CholeskySolveCpuKernel : public CpuKernel {
+ public:
+  CholeskySolveCpuKernel() = default;
+  ~CholeskySolveCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename T>
+  uint32_t ComputeKernel(CpuKernelContext &ctx, const bool &upper);
+};
+}  // namespace aicpu
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc
new file mode 100644
index 00000000000..965f4c95aef
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.cc
@@ -0,0 +1,447 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "combined_non_max_suppression.h" +#include +#include +#include +#include +#include +#include "cpu_kernel_utils.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" + +namespace { +const uint32_t kInputNum = 6; +const uint32_t kOutputNum = 4; +const char *kCombinedNonMaxSuppression = "CombinedNonMaxSuppression"; + +void alloc_zeros(float *arr, int arr_len) { + for (int i = 0; i < arr_len; i++) { + arr[i] = (float)0.0; + } +} + +void alloc_zeros(int *arr, int arr_len) { + for (int i = 0; i < arr_len; i++) { + arr[i] = 0; + } +} +} // namespace + +namespace aicpu { +// Normalize the diagonal to the input mode of bottom left and top right +void CombinedNonMaxSuppressionCpuKernel::regular_input2buffer(float **boxes_buffer, float *box_src, + const int class_idx) { + /** + * shape of box_src + * box_src[num_boxes*q*4] + * ways to visit box_src[i][class_idx][k] which stored by 1-dimension + * box_src[i][class_idx][k]=box_src[i*q*4+class_idx*4+k] + */ + int box_len1; + int sub_box_len1 = q * 4; + int box_len2 = (class_idx << 2); + for (int i = 0; i < num_boxes; i++) { + box_len1 = i * sub_box_len1 + box_len2; + if (box_src[box_len1] > box_src[box_len1 + 2]) { + boxes_buffer[i][0] = box_src[box_len1 + 2]; + boxes_buffer[i][2] = box_src[box_len1 + 0]; + } else { + boxes_buffer[i][0] = box_src[box_len1 + 0]; + boxes_buffer[i][2] = box_src[box_len1 + 2]; + } + if (box_src[box_len1 + 1] > box_src[box_len1 + 3]) { + boxes_buffer[i][1] = box_src[box_len1 + 3]; + boxes_buffer[i][3] = box_src[box_len1 + 1]; + } else { + boxes_buffer[i][1] = box_src[box_len1 + 1]; + boxes_buffer[i][3] = box_src[box_len1 + 3]; + } + } +} + +// Calculate the area ratio of the intersection of two squares +float CombinedNonMaxSuppressionCpuKernel::IOU(float **boxes_buffer, int i, int j) { + const float *box_a = boxes_buffer[i]; + const float *box_b = boxes_buffer[j]; + float lx, ly, rx, ry; + float w, h; + float area; + float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]); + float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); + if (area_a <= 0 || area_b <= 0) { + return 0.0; + } + lx = box_a[0] > box_b[0] ? box_a[0] : box_b[0]; + ly = box_a[1] > box_b[1] ? box_a[1] : box_b[1]; + rx = box_a[2] < box_b[2] ? box_a[2] : box_b[2]; + ry = box_a[3] < box_b[3] ? box_a[3] : box_b[3]; + w = rx > lx ? (rx - lx) : 0; + h = ry > ly ? 
(ry - ly) : 0; + area = w * h; + return area / (area_a + area_b - area); +} + +/** + * if soft_nms_sigma > 0.0, soft_nms is used, means update by score=score*exp(scale*iou^2) + * if soft_nms_sigma <= 0.0, nms is used, means delete it when iou > iou_threshold + * run non max suppression per bath per class + */ +void CombinedNonMaxSuppressionCpuKernel::non_max_suppression(float **boxes_buffer, float *scores_buffer, + std::vector &selected) { + std::priority_queue pq; + for (int i = 0; i < num_boxes; i++) { + if (scores_buffer[i] > score_threshold) { + pq.push(non_max_suppression_local::score_index(i, scores_buffer[i], 0)); + } + } + + float scale = static_cast(0.0); + bool is_soft_nms = soft_nms_sigma > static_cast(0.0); + if (is_soft_nms) { + scale = static_cast(-0.5) / soft_nms_sigma; + } + + float similarity; + float original_score; + non_max_suppression_local::score_index next_si; + + while (((int)selected.size() < size_per_class) && (!pq.empty())) { + next_si = pq.top(); + original_score = next_si.score; + pq.pop(); + bool should_hard_suppress = false; + + for (int j = static_cast(selected.size()) - 1; j >= next_si.suppress_begin_index; j--) { + similarity = IOU(boxes_buffer, next_si.box_index, selected[j]); + if (is_soft_nms) { + next_si.score *= + similarity <= iou_threshold ? std::exp(scale * similarity * similarity) : static_cast(0.0); + } + if (!is_soft_nms && similarity > iou_threshold) { + should_hard_suppress = true; + break; + } + if (next_si.score <= score_threshold) break; + } + + next_si.suppress_begin_index = selected.size(); + if (!should_hard_suppress) { + if (next_si.score == original_score) { + selected.push_back(next_si.box_index); + continue; + } + if (next_si.score > score_threshold) { + pq.push(next_si); + } + } + } +} + +void CombinedNonMaxSuppressionCpuKernel::nms_perclass( + float *boxes, float *scores, std::vector &sub_result_vec, int &result_size) { + int k = 0; + int box_idx; + int boxe_len1; + int sub_box_len1 = q * 4; + int box_len2 = 0; + float **boxes_buffer = new float *[num_boxes](); + float *scores_buffer = new float[num_boxes](); + for (int i = 0; i < num_boxes; i++) { + boxes_buffer[i] = new float[4]; + } + /** + * shape of score and boxes + * score[num_boxes*num_class] + * boxes[num_boxes*q*4] + */ + if (q == 1) { + regular_input2buffer(boxes_buffer, boxes, 0); + } + for (int j = 0; j < num_class; j++) { + for (int i = 0; i < num_boxes; i++) { + scores_buffer[i] = scores[i * num_class + j]; + } + if (q > 1) { + regular_input2buffer(boxes_buffer, boxes, j); + box_len2 = j * 4; + } + std::vector selected; + non_max_suppression(boxes_buffer, scores_buffer, selected); + for (int i = 0; i < (int)selected.size(); i++) { + box_idx = selected[i]; + boxe_len1 = box_idx * sub_box_len1 + box_len2; + sub_result_vec[k++] = {box_idx, + scores_buffer[box_idx], + j, + {boxes[boxe_len1 + 0], boxes[boxe_len1 + 1], boxes[boxe_len1 + 2], boxes[boxe_len1 + 3]}}; + } + result_size += selected.size(); + } + for (int i = 0; i < num_boxes; i++) { + delete[] boxes_buffer[i]; + } + delete[] boxes_buffer; + delete[] scores_buffer; + return; +} + +uint32_t CombinedNonMaxSuppressionCpuKernel::nms_perbath(CpuKernelContext &ctx, float *boxes, float *scores, + float *nmsed_boxes, float *nmsed_scores, float *nmsed_class, + int *valid_detection) { + alloc_zeros(nmsed_boxes, num_bath * num_detection * 4); + alloc_zeros(nmsed_scores, num_bath * num_detection); + alloc_zeros(nmsed_class, num_bath * num_detection); + alloc_zeros(valid_detection, num_bath); + const float box_min = 0.0; + 
const float box_max = 1.0; + /** + * shape of scores and boxes: + * scores[num_bath*num_boxes*num_class] + * boxes[num_bath*num_boxes*q*4] + */ + int score_len2 = num_boxes * num_class; + int boxes_len2 = num_boxes * q * 4; + auto shard_nms = [&](size_t start, size_t end) { + for (int i = start; i < (int)end; i++) { + int per_detections = 0; + int scores_index = 0; + int result_size = 0; + std::vector result_vec(size_per_class * num_class, + {0, 0.0, 0, {0.0, 0.0, 0.0, 0.0}}); + nms_perclass(boxes + i * boxes_len2, scores + i * score_len2, result_vec, result_size); + if (!pad_per_class) { + per_detections = std::min(result_size, max_total_size); + } else { + per_detections = std::min(result_size, num_detection); + } + std::sort(result_vec.begin(), result_vec.begin() + result_size, non_max_suppression_local::result_cmp); + scores_index = i * num_detection; + for (int k = 0; k < per_detections; k++) { + if (clip_boxes) { + nmsed_boxes[(scores_index << 2) + 0] = std::max(std::min(result_vec[k].box_coord[0], box_max), box_min); + nmsed_boxes[(scores_index << 2) + 1] = std::max(std::min(result_vec[k].box_coord[1], box_max), box_min); + nmsed_boxes[(scores_index << 2) + 2] = std::max(std::min(result_vec[k].box_coord[2], box_max), box_min); + nmsed_boxes[(scores_index << 2) + 3] = std::max(std::min(result_vec[k].box_coord[3], box_max), box_min); + nmsed_scores[scores_index] = result_vec[k].score; + nmsed_class[scores_index] = (float)result_vec[k].class_idx; + } else { + nmsed_boxes[(scores_index << 2) + 0] = result_vec[k].box_coord[0]; + nmsed_boxes[(scores_index << 2) + 1] = result_vec[k].box_coord[1]; + nmsed_boxes[(scores_index << 2) + 2] = result_vec[k].box_coord[2]; + nmsed_boxes[(scores_index << 2) + 3] = result_vec[k].box_coord[3]; + nmsed_scores[scores_index] = result_vec[k].score; + nmsed_class[scores_index] = (float)result_vec[k].class_idx; + } + scores_index++; + } + valid_detection[i] = per_detections; + } + }; + uint32_t min_core_num = 1; + int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + if (max_core_num > num_bath) { + max_core_num = num_bath; + } + KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_bath, num_bath / max_core_num, shard_nms), + "CombinedNonMaxSuppression Compute failed in nms_perbath stage."); + return KERNEL_STATUS_OK; +} + +uint32_t CombinedNonMaxSuppressionCpuKernel::Compute(CpuKernelContext &ctx) { + // check params + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), + "CombinedNonMaxSuppression check input and output number failed."); + KERNEL_HANDLE_ERROR(CombinedNonMaxSuppressionCheck(ctx), "CombinedNonMaxSuppression check params failed."); + CombinedNonMaxSuppressionCompute(ctx); + return KERNEL_STATUS_OK; +} + +uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx) { + KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Input(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Input(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 2 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Input(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 3 data failed."); + if (ctx.Input(4) != nullptr) { + KERNEL_CHECK_NULLPTR(ctx.Input(4)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 4 data failed."); + } + KERNEL_CHECK_NULLPTR(ctx.Input(5)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 5 data failed."); + 
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 0 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Output(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 1 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Output(2)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 2 data failed."); + KERNEL_CHECK_NULLPTR(ctx.Output(3)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 3 data failed."); + KERNEL_CHECK_FALSE((ctx.Input(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of input0 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(0)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of input1 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(1)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Input(2)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID, + "The data type of input2 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(2)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Input(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID, + "The data type of input3 [%s] must be [DT_INT32].", DTypeStr(ctx.Input(3)->GetDataType()).c_str()); + if (ctx.Input(4) != NULL) { + KERNEL_CHECK_FALSE((ctx.Input(4)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of input4 [%s] must be [DT_FLOAT].", + DTypeStr(ctx.Input(4)->GetDataType()).c_str()); + } + KERNEL_CHECK_FALSE((ctx.Input(5)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of input5 [%s] must be [DT_FLOAT].", DTypeStr(ctx.Input(5)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Output(0)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of output0 [%s] must be [DT_FLOAT].", + DTypeStr(ctx.Output(0)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Output(1)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of output1 [%s] must be [DT_FLOAT].", + DTypeStr(ctx.Output(1)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Output(2)->GetDataType() == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID, + "The data type of output2 [%s] must be [DT_FLOAT].", + DTypeStr(ctx.Output(2)->GetDataType()).c_str()); + KERNEL_CHECK_FALSE((ctx.Output(3)->GetDataType() == DT_INT32), KERNEL_STATUS_PARAM_INVALID, + "The data type of output3 [%s] must be [DT_INT32].", + DTypeStr(ctx.Output(3)->GetDataType()).c_str()); + auto input0_shape = ctx.Input(0)->GetTensorShape(); + auto input1_shape = ctx.Input(1)->GetTensorShape(); + auto input2_shape = ctx.Input(2)->GetTensorShape(); + auto input3_shape = ctx.Input(3)->GetTensorShape(); + auto input5_shape = ctx.Input(5)->GetTensorShape(); + KERNEL_CHECK_FALSE((input0_shape->GetDims() == 4), KERNEL_STATUS_PARAM_INVALID, "The input0's dims [%d] must be 4", + input0_shape->GetDims()); + KERNEL_CHECK_FALSE((input1_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The input1's dims [%d] must be 3", + input1_shape->GetDims()); + KERNEL_CHECK_FALSE( + (input2_shape->GetDims() == 0 || (input2_shape->GetDims() == 1 && input2_shape->GetDimSize(0) == 1)), + KERNEL_STATUS_PARAM_INVALID, "The input2's dims [%d] must be 0 or 1x1", input2_shape->GetDims()); + KERNEL_CHECK_FALSE( + (input3_shape->GetDims() == 0 || (input3_shape->GetDims() == 1 && input3_shape->GetDimSize(0) == 1)), + KERNEL_STATUS_PARAM_INVALID, "The input3's dims [%d] must be 0 or 1x1", input3_shape->GetDims()); + if (ctx.Input(4) != nullptr) { + auto input4_shape = ctx.Input(4)->GetTensorShape(); + 
+    KERNEL_CHECK_FALSE(
+      (input4_shape->GetDims() == 0 || (input4_shape->GetDims() == 1 && input4_shape->GetDimSize(0) == 1)),
+      KERNEL_STATUS_PARAM_INVALID, "The input4's dims [%d] must be 0 or 1x1", input4_shape->GetDims());
+  }
+  KERNEL_CHECK_FALSE(
+    (input5_shape->GetDims() == 0 || (input5_shape->GetDims() == 1 && input5_shape->GetDimSize(0) == 1)),
+    KERNEL_STATUS_PARAM_INVALID, "The input5's dims [%d] must be 0 or 1x1", input5_shape->GetDims());
+  auto output0_shape = ctx.Output(0)->GetTensorShape();
+  auto output1_shape = ctx.Output(1)->GetTensorShape();
+  auto output2_shape = ctx.Output(2)->GetTensorShape();
+  auto output3_shape = ctx.Output(3)->GetTensorShape();
+  KERNEL_CHECK_FALSE((output0_shape->GetDims() == 3), KERNEL_STATUS_PARAM_INVALID, "The output0's dims [%d] must be 3",
+                     output0_shape->GetDims());
+  KERNEL_CHECK_FALSE((output1_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output1's dims [%d] must be 2",
+                     output1_shape->GetDims());
+  KERNEL_CHECK_FALSE((output2_shape->GetDims() == 2), KERNEL_STATUS_PARAM_INVALID, "The output2's dims [%d] must be 2",
+                     output2_shape->GetDims());
+  KERNEL_CHECK_FALSE((output3_shape->GetDims() == 1), KERNEL_STATUS_PARAM_INVALID, "The output3's dims [%d] must be 1",
+                     output3_shape->GetDims());
+  KERNEL_CHECK_FALSE((input0_shape->GetDimSize(0) == input1_shape->GetDimSize(0)), KERNEL_STATUS_PARAM_INVALID,
+                     "The input0's 1st dims [%d] need be same with the input1's 1st dims [%d]",
+                     input0_shape->GetDimSize(0), input1_shape->GetDimSize(0));
+  KERNEL_CHECK_FALSE((input0_shape->GetDimSize(1) == input1_shape->GetDimSize(1)), KERNEL_STATUS_PARAM_INVALID,
+                     "The input0's 2nd dims [%d] need be same with the input1's 2nd dims [%d]",
+                     input0_shape->GetDimSize(1), input1_shape->GetDimSize(1));
+  KERNEL_CHECK_FALSE((input0_shape->GetDimSize(2) == input1_shape->GetDimSize(2) || input0_shape->GetDimSize(2) == 1),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "The input0's 3rd dims [%d] need be same with the input1's 3rd dims [%d] or 1",
+                     input0_shape->GetDimSize(2), input1_shape->GetDimSize(2));
+  KERNEL_CHECK_FALSE((input0_shape->GetDimSize(3) == 4), KERNEL_STATUS_PARAM_INVALID,
+                     "The input0's 4th dims [%d] must be 4", input0_shape->GetDimSize(3));
+  KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == output1_shape->GetDimSize(0) &&
+                      output0_shape->GetDimSize(0) == output2_shape->GetDimSize(0) &&
+                      output0_shape->GetDimSize(0) == output3_shape->GetDimSize(0)),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "The output0's 1st dims [%d], output1's 1st dims [%d],"
+                     " output2's 1st dims [%d], output3's 1st dims [%d], need be same with each other",
+                     output0_shape->GetDimSize(0), output1_shape->GetDimSize(0), output2_shape->GetDimSize(0),
+                     output3_shape->GetDimSize(0));
+  KERNEL_CHECK_FALSE((output0_shape->GetDimSize(1) == output1_shape->GetDimSize(1) &&
+                      output0_shape->GetDimSize(1) == output2_shape->GetDimSize(1)),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "The output0's 2nd dims [%d], output1's 2nd dims [%d], output2's 2nd dims [%d],"
+                     " need be same with each other",
+                     output0_shape->GetDimSize(1), output1_shape->GetDimSize(1), output2_shape->GetDimSize(1));
+  KERNEL_LOG_INFO(
+    "CombinedNonMaxSuppressionCpuKernel[%s], input0: size[%llu], "
+    "input1: size[%llu]",
+    ctx.GetOpType().c_str(), ctx.Input(0)->GetDataSize(), ctx.Input(1)->GetDataSize());
+  KERNEL_LOG_INFO(
+    "output0: size[%llu], output1: size[%llu],"
+    " output2: size[%llu], output3: size[%llu].",
+    ctx.Output(0)->GetDataSize(), ctx.Output(1)->GetDataSize(), ctx.Output(2)->GetDataSize(),
+    ctx.Output(3)->GetDataSize());
+
+  return KERNEL_STATUS_OK;
+}
+
+uint32_t CombinedNonMaxSuppressionCpuKernel::CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx) {
+  float *boxes = reinterpret_cast<float *>(ctx.Input(0)->GetData());
+  float *scores = reinterpret_cast<float *>(ctx.Input(1)->GetData());
+  max_output_size_per_class = *(reinterpret_cast<int32_t *>(ctx.Input(2)->GetData()));
+  max_total_size = *(reinterpret_cast<int32_t *>(ctx.Input(3)->GetData()));
+  iou_threshold = *(reinterpret_cast<float *>(ctx.Input(4)->GetData()));
+  score_threshold = *(reinterpret_cast<float *>(ctx.Input(5)->GetData()));
+  num_bath = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(0));
+  num_boxes = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(1));
+  q = (int)(ctx.Input(0)->GetTensorShape()->GetDimSize(2));
+  num_class = (int)(ctx.Input(1)->GetTensorShape()->GetDimSize(2));
+  pad_per_class = false;
+  clip_boxes = true;
+  if (ctx.GetAttr("pad_per_class") != nullptr) {
+    pad_per_class = ctx.GetAttr("pad_per_class")->GetBool();
+  }
+  if (ctx.GetAttr("clip_boxes") != nullptr) {
+    clip_boxes = ctx.GetAttr("clip_boxes")->GetBool();
+  }
+  float *nmsed_boxes = reinterpret_cast<float *>(ctx.Output(0)->GetData());
+  float *nmsed_scores = reinterpret_cast<float *>(ctx.Output(1)->GetData());
+  float *nmsed_class = reinterpret_cast<float *>(ctx.Output(2)->GetData());
+  int *valid_detection = reinterpret_cast<int *>(ctx.Output(3)->GetData());
+  auto output0_shape = ctx.Output(0)->GetTensorShape();
+  auto output1_shape = ctx.Output(1)->GetTensorShape();
+  auto output2_shape = ctx.Output(2)->GetTensorShape();
+  auto output3_shape = ctx.Output(3)->GetTensorShape();
+  size_per_class = max_output_size_per_class < num_boxes ? max_output_size_per_class : num_boxes;
+  soft_nms_sigma = 0.0;
+  if (pad_per_class) {
+    num_detection = std::min(max_total_size, max_output_size_per_class * num_class);
+  } else {
+    num_detection = max_total_size;
+  }
+  KERNEL_CHECK_FALSE((output0_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
+                     "The output0's 1st dims [%d] must be [%d]", output0_shape->GetDimSize(0), num_bath);
+  KERNEL_CHECK_FALSE((output1_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
+                     "The output1's 1st dims [%d] must be [%d]", output1_shape->GetDimSize(0), num_bath);
+  KERNEL_CHECK_FALSE((output2_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
+                     "The output2's 1st dims [%d] must be [%d]", output2_shape->GetDimSize(0), num_bath);
+  KERNEL_CHECK_FALSE((output3_shape->GetDimSize(0) == num_bath), KERNEL_STATUS_PARAM_INVALID,
+                     "The output3's 1st dims [%d] must be [%d]", output3_shape->GetDimSize(0), num_bath);
+  KERNEL_CHECK_FALSE((max_output_size_per_class > 0), KERNEL_STATUS_PARAM_INVALID,
+                     "max_output_size_per_class [%d] must be > 0", max_output_size_per_class);
+  KERNEL_CHECK_FALSE((max_total_size > 0), KERNEL_STATUS_PARAM_INVALID, "max_total_size [%d] must be > 0",
+                     max_total_size);
+  KERNEL_CHECK_FALSE((iou_threshold >= 0 && iou_threshold <= 1), KERNEL_STATUS_PARAM_INVALID,
+                     "iou_threshold [%f] must be in [0,1]", iou_threshold);
+  KERNEL_CHECK_FALSE(((int)output0_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
+                     "The output0's 2nd dims [%d] need be same with %d", output0_shape->GetDimSize(1), num_detection);
+  KERNEL_CHECK_FALSE(((int)output1_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
+                     "The output1's 2nd dims [%d] need be same with %d", output1_shape->GetDimSize(1), num_detection);
+  KERNEL_CHECK_FALSE(((int)output2_shape->GetDimSize(1) == num_detection), KERNEL_STATUS_PARAM_INVALID,
+                     "The
output2's 2nd dims [%d] need be same with %d", output2_shape->GetDimSize(1), num_detection); + nms_perbath(ctx, boxes, scores, nmsed_boxes, nmsed_scores, nmsed_class, valid_detection); + return KERNEL_STATUS_OK; +} + +REGISTER_CPU_KERNEL(kCombinedNonMaxSuppression, CombinedNonMaxSuppressionCpuKernel); +} // namespace aicpu \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.h new file mode 100644 index 00000000000..e1d53c056d0 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/combined_non_max_suppression.h @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_ +#define AICPU_KERNELS_NORMALIZED_COMBINEDNONMAXSUPPRESSION_H_ + +#include "cpu_ops_kernel.h" +#include "utils/bcast.h" + +namespace non_max_suppression_local { +struct score_index { + int box_index; + float score; + int suppress_begin_index; + score_index() {} + score_index(int bi, float s, int sbi) : box_index(bi), score(s), suppress_begin_index(sbi) {} + bool operator<(const score_index &b) const { + return (score < b.score) || ((score == b.score) && (box_index > b.box_index)); + } +}; + +struct result_para { + int box_index; + float score; + int class_idx; + float box_coord[4]; +}; + +bool result_cmp(result_para &a, result_para &b) { return a.score > b.score; } +} // namespace non_max_suppression_local + +namespace aicpu { + +class CombinedNonMaxSuppressionCpuKernel : public CpuKernel { + public: + CombinedNonMaxSuppressionCpuKernel() = default; + ~CombinedNonMaxSuppressionCpuKernel() override = default; + + protected: + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + uint32_t CombinedNonMaxSuppressionCheck(CpuKernelContext &ctx); + + uint32_t CombinedNonMaxSuppressionCompute(CpuKernelContext &ctx); + uint32_t nms_perbath(CpuKernelContext &, float *, float *, float *, float *, float *, int *); + void regular_input2buffer(float **, float *, const int); + float IOU(float **, int, int); + void non_max_suppression(float **, float *, std::vector &); + void nms_perclass(float *, float *, std::vector &, int &); + int num_bath; + int num_boxes; + int q; + int num_class; + // per batch size + int num_detection; + int max_total_size; + // The length of each type of selection defined by the user + int max_output_size_per_class; + // Calculation num_detection length + int size_per_class; + // When lower than a score_threshold, delete the relevant box + float score_threshold; + // When it is higher than the threshold value, according to the soft_nms_sigma determines deletion or decay + float iou_threshold; + float soft_nms_sigma; + bool pad_per_class; + bool clip_boxes; +}; +} // namespace aicpu +#endif diff --git 
a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.cc new file mode 100644 index 00000000000..85379c6d7a5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "complex.h" +#include "Eigen/Eigen" +#include "cpu_kernel_utils.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" + +namespace { +const uint32_t kOutputNum = 1; +const uint32_t kInputNum = 2; +const char *kComplex = "Complex"; +constexpr int64_t kFloatMaxNums = 8 * 128 * 1024; +constexpr int64_t kDoubleMaxNums = 16 * 128 * 1024; +#define Complex_COMPUTE_CASE(IN_DTYPE, IN_TYPE, OUT_DTYPE, CTX) \ + case (IN_DTYPE): { \ + switch (OUT_DTYPE) { \ + case (DT_COMPLEX64): { \ + uint32_t result = ComplexCompute>(CTX); \ + if (result != KERNEL_STATUS_OK) { \ + KERNEL_LOG_ERROR("Complex kernel compute failed."); \ + return result; \ + } \ + break; \ + } \ + case (DT_COMPLEX128): { \ + uint32_t result = ComplexCompute>(CTX); \ + if (result != KERNEL_STATUS_OK) { \ + KERNEL_LOG_ERROR("Complex kernel compute failed."); \ + return result; \ + } \ + break; \ + } \ + default: \ + KERNEL_LOG_ERROR("Complex kernel output data type [%s] not support.", DTypeStr(OUT_DTYPE).c_str()); \ + return KERNEL_STATUS_PARAM_INVALID; \ + } \ + break; \ + } +} // namespace + +namespace aicpu { +uint32_t ComplexCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kComplex); + DataType input_type = ctx.Input(0)->GetDataType(); + switch (input_type) { + Complex_COMPUTE_CASE(DT_FLOAT, float, DT_COMPLEX64, ctx) + Complex_COMPUTE_CASE(DT_DOUBLE, double, DT_COMPLEX128, ctx) default + : KERNEL_LOG_ERROR("Complex kernel input data type [%s] not support.", DTypeStr(input_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + return KERNEL_STATUS_OK; +} +template +uint32_t ComplexCpuKernel::ComplexCompute(CpuKernelContext &ctx) { + auto input0 = reinterpret_cast(ctx.Input(0)->GetData()); + auto input1 = reinterpret_cast(ctx.Input(1)->GetData()); + auto output = reinterpret_cast(ctx.Output(0)->GetData()); + auto data_type = ctx.Input(0)->GetDataType(); + int64_t data_num = ctx.Output(0)->NumElements(); + int64_t data_size = data_num * sizeof(T); + if ((data_type == DT_FLOAT && data_size <= kFloatMaxNums) || + (data_type == DT_DOUBLE && data_size <= kDoubleMaxNums)) { + for (int64_t index = 0; index < data_num; ++index) { + *(output + index) = t(*(input0 + index), *(input1 + index)); + } + } else { + uint32_t min_core_num = 1; + int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + if (max_core_num > data_num) { + max_core_num = data_num; + } + auto shard_complex = [&](size_t start, size_t end) { + for 
(size_t index = start; index < end; ++index) { + *(output + index) = t(*(input0 + index), *(input1 + index)); + } + }; + KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_complex), + "complex Compute failed"); + } + return KERNEL_STATUS_OK; +} +REGISTER_CPU_KERNEL(kComplex, ComplexCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.h new file mode 100644 index 00000000000..4cb967cad2f --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex.h @@ -0,0 +1,40 @@ +/** + * Copyright (C) 2020-2021. Huawei Technologies Co., Ltd. All rights reserved. + + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + * + * @brief + * + * @version 1.0 + * + */ + +#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_H_ +#define AICPU_KERNELS_NORMALIZED_COMPLEX_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class ComplexCpuKernel : public CpuKernel { + public: + ComplexCpuKernel() = default; + ~ComplexCpuKernel() override = default; + + protected: + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + uint32_t ComplexCheck(CpuKernelContext &ctx); + + template + uint32_t ComplexCompute(CpuKernelContext &ctx); +}; +} // namespace aicpu +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.cc new file mode 100644 index 00000000000..be6cc0232a9 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2021 Jilin University + * Copyright 2020 Huawei Technologies Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "complex_abs.h" + +#include + +#include "cpu_kernel_utils.h" +#include "cpu_types.h" +#include "kernel_log.h" +#include "status.h" +#include "utils/kernel_util.h" + +namespace { +const std::uint32_t kComplexAbsInputNum{1u}; +const std::uint32_t kComplexAbsOutputNum{1u}; +const char *kComplexAbs{"ComplexAbs"}; +const std::int64_t kComplexAbsParallelNum{32 * 1024}; +} // namespace + +namespace aicpu { +namespace detail { +template +inline const typename T::value_type ScalarComplexAbs(const T &x) { + return std::abs(x); +} +inline std::uint32_t ParallelForComplexAbs(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size, + const std::function &work) { + if (total > kComplexAbsParallelNum) + return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work); + else + work(0, total); + return KERNEL_STATUS_OK; +} +template +inline std::uint32_t ComputeComplexAbsKernel(const CpuKernelContext &ctx) { + T *input0{static_cast(ctx.Input(0)->GetData())}; + typename T::value_type *output{static_cast(ctx.Output(0)->GetData())}; + std::int64_t total{ctx.Input(0)->NumElements()}; + std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)}; + std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)}; + return ParallelForComplexAbs(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) { + std::transform(input0 + begin, input0 + end, output + begin, ScalarComplexAbs); + }); +} + +template +inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) { + std::uint32_t result{ComputeComplexAbsKernel(ctx)}; + if (result != KERNEL_STATUS_OK) { + KERNEL_LOG_ERROR("ComplexAbs compute failed."); + } + return result; +} + +inline std::uint32_t ExtraCheckComplexAbs(const CpuKernelContext &ctx) { + if (ctx.Input(0)->GetDataType() == ctx.Output(0)->GetDataType()) { + KERNEL_LOG_ERROR( + "The data type of the input [%s] should not be the same as the output " + "[%s].", + DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize() * 2) { + KERNEL_LOG_ERROR( + "The data size of the input [%llu] need be as twice as the output " + "[%llu].", + ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize()); + return KERNEL_STATUS_PARAM_INVALID; + } + return KERNEL_STATUS_OK; +} + +inline std::uint32_t CheckComplexAbs(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) { + return NormalCheck(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? KERNEL_STATUS_PARAM_INVALID + : ExtraCheckComplexAbs(ctx); +} + +inline std::uint32_t ComputeComplexAbs(const CpuKernelContext &ctx) { + DataType input_type{ctx.Input(0)->GetDataType()}; + switch (input_type) { + case DT_COMPLEX64: + return ComputeComplexAbs>(ctx); + case DT_COMPLEX128: + return ComputeComplexAbs>(ctx); + default: + KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } +} +} // namespace detail + +std::uint32_t ComplexAbsCpuKernel::Compute(CpuKernelContext &ctx) { + return detail::CheckComplexAbs(ctx, kComplexAbsInputNum, kComplexAbsOutputNum) ? 
KERNEL_STATUS_PARAM_INVALID + : detail::ComputeComplexAbs(ctx); +} + +REGISTER_CPU_KERNEL(kComplexAbs, ComplexAbsCpuKernel); +} // namespace aicpu \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.h new file mode 100644 index 00000000000..714c393d06e --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/complex_abs.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Jilin University + * Copyright 2020 Huawei Technologies Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_ +#define AICPU_KERNELS_NORMALIZED_COMPLEX_ABS_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class ComplexAbsCpuKernel final : public CpuKernel { + virtual std::uint32_t Compute(CpuKernelContext &ctx) override final; +}; +} // namespace aicpu +#endif \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.cc new file mode 100644 index 00000000000..d16cf8b8d29 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.cc @@ -0,0 +1,303 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "concat.h"
+
+#include "utils/kernel_util.h"
+
+using namespace std;
+
+namespace {
+const char *const Concat = "Concat";
+}
+
+namespace aicpu {
+uint32_t ConcatCpuKernel::CheckAndInitParams(CpuKernelContext &ctx) {
+  if (ctx.GetAttr("N") == nullptr) {
+    n_ = 1;
+  } else {
+    AttrValue *n_ptr = ctx.GetAttr("N");
+    n_ = n_ptr->GetInt();
+  }
+  // "x" is a list of at least 2 "tensor" objects of the same type
+  KERNEL_CHECK_FALSE((n_ >= 2), KERNEL_STATUS_PARAM_INVALID, "Attr N must >= 2, but got attr N[%lld]", n_);
+
+  uint32_t input_num = ctx.GetInputsSize();
+
+  // input_num is n_(concat tensor num) + 1(concat_dim)
+  KERNEL_CHECK_FALSE((static_cast<int64_t>(input_num) - 1 == n_), KERNEL_STATUS_PARAM_INVALID,
+                     "Input num must equal attr N[%lld] + 1,"
+                     " but got input num [%u]",
+                     n_, input_num);
+
+  Tensor *concat_dim_ptr = ctx.Input(0);
+  KERNEL_CHECK_NULLPTR(concat_dim_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim failed.");
+  auto concat_dim_shape_ptr = concat_dim_ptr->GetTensorShape();
+  KERNEL_CHECK_NULLPTR(concat_dim_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim shape failed.");
+  int32_t concat_dim_dims = concat_dim_shape_ptr->GetDims();
+  KERNEL_CHECK_FALSE((concat_dim_dims == 0) || ((concat_dim_dims == 1) && (concat_dim_shape_ptr->NumElements() == 1)),
+                     KERNEL_STATUS_PARAM_INVALID, "Input concat_dim should be a scalar integer, but got rank[%d].",
+                     concat_dim_dims);
+  int32_t concat_dim = 0;
+  DataType concat_dim_data_type = concat_dim_ptr->GetDataType();
+  KERNEL_CHECK_FALSE((concat_dim_data_type == DT_INT32 || concat_dim_data_type == DT_INT64),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "Input concat_dim data type must be DT_INT32 or DT_INT64,"
+                     " but got data type [%s].",
+                     DTypeStr(concat_dim_data_type).c_str());
+  auto concat_dim_data_ptr = concat_dim_ptr->GetData();
+  KERNEL_CHECK_NULLPTR(concat_dim_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input concat_dim data failed.");
+  if (concat_dim_data_type == DT_INT32) {
+    concat_dim = *reinterpret_cast<int32_t *>(concat_dim_data_ptr);
+  } else {
+    concat_dim = static_cast<int32_t>(*reinterpret_cast<int64_t *>(concat_dim_data_ptr));
+  }
+
+  Tensor *input0_ptr = ctx.Input(1);
+  DataType input0_data_type = input0_ptr->GetDataType();
+  for (int64_t i = 1; i < n_; i++) {
+    Tensor *inputi_ptr = ctx.Input(i + 1);
+    KERNEL_CHECK_NULLPTR(inputi_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x%lld failed.", i);
+    auto inputi_shape_ptr = inputi_ptr->GetTensorShape();
+    KERNEL_CHECK_NULLPTR(inputi_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x%lld shape failed.", i);
+    DataType inputi_data_type = inputi_ptr->GetDataType();
+    KERNEL_CHECK_FALSE((input0_data_type == inputi_data_type), KERNEL_STATUS_PARAM_INVALID,
+                       "Input tensors should have same type,"
+                       " but got %s and %s.",
+                       DTypeStr(input0_data_type).c_str(), DTypeStr(inputi_data_type).c_str());
+  }
+  auto input0_shape_ptr = input0_ptr->GetTensorShape();
+  input_dims_ = input0_shape_ptr->GetDims();
+  data_type_ = input0_ptr->GetDataType();
+  axis_ = concat_dim < 0 ?
concat_dim + input_dims_ : concat_dim; + KERNEL_CHECK_FALSE((0 <= axis_ && axis_ < input_dims_), KERNEL_STATUS_PARAM_INVALID, + "Input concat_dim need in the " + "range[%d, %d), but got %lld.", + -input_dims_, input_dims_, concat_dim); + inputs_flat_dim0_ = 1; + for (uint32_t d = 0; d < axis_; ++d) { + inputs_flat_dim0_ *= input0_shape_ptr->GetDimSize(d); + } + return KERNEL_STATUS_OK; +} + +uint32_t ConcatCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_CHECK_FALSE((CheckAndInitParams(ctx) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID, + "CheckAndInitParams failed."); + switch (data_type_) { + case DT_FLOAT16: + return DoCompute(ctx); + case DT_FLOAT: + return DoCompute(ctx); + case DT_DOUBLE: + return DoCompute(ctx); + case DT_INT8: + return DoCompute(ctx); + case DT_INT16: + return DoCompute(ctx); + case DT_INT32: + return DoCompute(ctx); + case DT_INT64: + return DoCompute(ctx); + case DT_UINT8: + return DoCompute(ctx); + case DT_UINT16: + return DoCompute(ctx); + case DT_UINT32: + return DoCompute(ctx); + case DT_UINT64: + return DoCompute(ctx); + case DT_COMPLEX64: + return DoCompute>(ctx); + case DT_COMPLEX128: + return DoCompute>(ctx); + default: + KERNEL_LOG_ERROR("unsupported datatype[%d]", DTypeStr(data_type_).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } +} + +template +uint32_t ConcatCpuKernel::PrepareInput(CpuKernelContext &ctx, + std::vector::ConstMatrix>> &inputs) { + inputs.reserve(n_); + output_concat_dim_ = 0; + auto input0_shape_ptr = ctx.Input(1)->GetTensorShape(); + for (uint32_t i = 0; i < n_; ++i) { + Tensor *input_i_ptr = ctx.Input(i + 1); + int64_t input_i_num = input_i_ptr->NumElements(); + if (input_i_num == 0) { + continue; + } + auto input_i_shape_ptr = input_i_ptr->GetTensorShape(); + int32_t input_i_dims = input_i_shape_ptr->GetDims(); + KERNEL_CHECK_FALSE((input_i_dims == input_dims_), KERNEL_STATUS_PARAM_INVALID, + "Ranks of inputs should match: shape[0]=%d vs. shape[%u]=%d", input_dims_, i, input_i_dims); + for (int32_t j = 0; j < input_dims_; ++j) { + int64_t dim_ij = input_i_shape_ptr->GetDimSize(j); + if (j == axis_) { + output_concat_dim_ += input_i_dims > 0 ? dim_ij : 1; + continue; + } + int64_t dim_0j = input0_shape_ptr->GetDimSize(j); + KERNEL_CHECK_FALSE((dim_0j == dim_ij), KERNEL_STATUS_PARAM_INVALID, + "Dimensions of inputs should match: shape[0][%d]=%lld vs." 
+ "shape[%u][%d]=%lld", + j, dim_0j, i, j, dim_ij); + } + + int64_t inputs_flat_dim1 = input_i_num / inputs_flat_dim0_; + auto input_i_data_ptr = input_i_ptr->GetData(); + KERNEL_CHECK_NULLPTR(input_i_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get input x%u data failed.", i); + auto input_i = std::make_shared::ConstMatrix>(reinterpret_cast(input_i_data_ptr), + inputs_flat_dim0_, inputs_flat_dim1); + KERNEL_CHECK_NULLPTR(input_i, KERNEL_STATUS_PARAM_INVALID, "Create input x%u failed!", i); + inputs.emplace_back(std::move(input_i)); + } + + if (input_dims_ == 0) { + output_concat_dim_ = n_; + } + return KERNEL_STATUS_OK; +} + +template +uint32_t ConcatCpuKernel::PrepareOutput(CpuKernelContext &ctx, std::shared_ptr::Matrix> &output) { + Tensor *output_ptr = ctx.Output(0); + KERNEL_CHECK_NULLPTR(output_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output failed."); + auto output_data_ptr = output_ptr->GetData(); + KERNEL_CHECK_NULLPTR(output_data_ptr, KERNEL_STATUS_PARAM_INVALID, "Get output data failed."); + int64_t output_num = output_ptr->NumElements(); + int64_t output_dim1 = output_num / inputs_flat_dim0_; + output = std::make_shared::Matrix>(reinterpret_cast(output_data_ptr), inputs_flat_dim0_, + output_dim1); + KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Create output matrix failed."); + return KERNEL_STATUS_OK; +} + +template +uint32_t ConcatCpuKernel::DoCompute(CpuKernelContext &ctx) { + std::vector::ConstMatrix>> inputs; + KERNEL_CHECK_FALSE((PrepareInput(ctx, inputs) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID, + "PrepareInput failed."); + std::shared_ptr::Matrix> output = nullptr; + KERNEL_CHECK_FALSE((PrepareOutput(ctx, output) == KERNEL_STATUS_OK), KERNEL_STATUS_PARAM_INVALID, + "PrepareOutput failed."); + if (inputs.size() > 0) { + return ConcatCompute(ctx, inputs, output); + } + return KERNEL_STATUS_OK; +} + +template +uint32_t ConcatCpuKernel::ConcatCompute(CpuKernelContext &ctx, + const std::vector::ConstMatrix>> &inputs, + std::shared_ptr::Matrix> &output) { + size_t num_inputs = inputs.size(); + std::vector sizes; + sizes.reserve(num_inputs); + int64_t row_size = 0; + for (const auto &input : inputs) { + sizes.push_back(input->dimension(1)); + row_size += sizes.back(); + } + uint32_t ret = KERNEL_STATUS_OK; + auto work = [&row_size, &sizes, &inputs, &output, &num_inputs, &ret](int64_t start, int64_t end) { + if (row_size == 0) { + ret = KERNEL_STATUS_PARAM_INVALID; + return; + } + int64_t skipped_rows = start / row_size; + T *out = output->data() + skipped_rows * row_size; + T *out_start = output->data() + start; + T *out_end = output->data() + end; + + // Handle partial row at start + if (out < out_start) { + for (size_t j = 0; j < num_inputs; ++j) { + ptrdiff_t size = sizes[j]; + ptrdiff_t offset = out_start - out; + if (size <= offset) { + out += size; + continue; + } + const T *inp = &(*inputs[j])(skipped_rows, 0); + if (offset > 0) { + out += offset; + inp += offset; + size -= offset; + } + size = std::min(size, out_end - out); + KERNEL_CHECK_FALSE_EXEC((size > 0), break) + size_t copy_size = size * sizeof(T); + error_t ret = memcpy_s(out, copy_size, inp, copy_size); + if (ret != EOK) { + KERNEL_LOG_ERROR("Memcpy failed."); + ret = KERNEL_STATUS_INNER_ERROR; + return; + } + out += size; + } + ++skipped_rows; + } + if (out < out_start || out > out_end) { + KERNEL_LOG_ERROR("Out[%llx] not in range[%llx, %llx)", out, out_start, out_end); + ret = KERNEL_STATUS_INNER_ERROR; + return; + } + // Copy remaining data. 
+ std::vector inp; + inp.reserve(num_inputs); + for (const auto &input : inputs) { + inp.push_back(&(*input)(skipped_rows, 0)); + } + const int64_t dim0 = output->dimension(0); + for (int64_t i = skipped_rows; i < dim0; ++i) { + for (int64_t j = 0; j < static_cast(num_inputs); ++j) { + ptrdiff_t size = std::min(sizes[j], out_end - out); + size_t copy_size = size * sizeof(T); + auto ret = memcpy_s(out, copy_size, inp[j], copy_size); + if (ret != EOK) { + KERNEL_LOG_ERROR("Memcpy size[%zu] from inp[%llx] to out[%llx] failed.", copy_size, inp[j], out); + ret = KERNEL_STATUS_INNER_ERROR; + return; + } + out += size; + inp[j] += size; + KERNEL_CHECK_FALSE_EXEC((out != out_end), return ); + } + } + }; + const int64_t kParallelDataNumSameShapeBig = 255 * 1024; + int64_t data_num = output->size(); + uint32_t min_core_num = 1; + uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum); + + if (data_num >= kParallelDataNumSameShapeBig) { + max_core_num = std::min(max_core_num, 6U); + } else { + max_core_num = std::min(max_core_num, 1U); + } + if (max_core_num == 0) { + KERNEL_LOG_ERROR("max_core_num could not be 0."); + } + CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, work); + KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "ConcatCpuKernel failed."); + return KERNEL_STATUS_OK; +} + +REGISTER_CPU_KERNEL(Concat, ConcatCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.h new file mode 100644 index 00000000000..9f11024f66a --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/concat.h @@ -0,0 +1,76 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_KERNELS_NORMALIZED_CONCAT_H_ +#define AICPU_KERNELS_NORMALIZED_CONCAT_H_ + +#include +#include + +#include "cpu_ops_kernel.h" +#include "cpu_kernel_utils.h" +#include "kernel_log.h" +#include "securec.h" +#include "status.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace aicpu { +const uint32_t NumIndices = 2; +template +struct TTypes { + // Rank-2 tensor (matrix) of scalar type T. 
+ using Matrix = Eigen::TensorMap, Eigen::Aligned>; + + using ConstMatrix = + Eigen::TensorMap, Eigen::Aligned>; +}; + +class ConcatCpuKernel : public CpuKernel { + public: + ConcatCpuKernel() + : data_type_(DT_DOUBLE), input_dims_(0), n_(0), output_concat_dim_(0), axis_(0), inputs_flat_dim0_(0) {} + + ~ConcatCpuKernel() = default; + + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + uint32_t CheckAndInitParams(CpuKernelContext &ctx); + + template + uint32_t PrepareInput(CpuKernelContext &ctx, std::vector::ConstMatrix>> &inputs); + + template + uint32_t PrepareOutput(CpuKernelContext &ctx, std::shared_ptr::Matrix> &output); + + template + uint32_t DoCompute(CpuKernelContext &ctx); + + template + uint32_t ConcatCompute(CpuKernelContext &ctx, + const std::vector::ConstMatrix>> &inputs, + std::shared_ptr::Matrix> &output); + + private: + DataType data_type_; + int32_t input_dims_; + int64_t n_; + int64_t output_concat_dim_; + int64_t axis_; + int64_t inputs_flat_dim0_; +}; +} // namespace aicpu +#endif \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.cc new file mode 100644 index 00000000000..a13bcf3dad1 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Jilin University + * Copyright 2020 Huawei Technologies Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cos.h" + +#include + +#include "cpu_kernel_utils.h" +#include "cpu_types.h" +#include "kernel_log.h" +#include "status.h" +#include "utils/kernel_util.h" + +namespace { +const std::uint32_t kCosInputNum{1u}; +const std::uint32_t kCosOutputNum{1u}; +const char *kCos{"Cos"}; +const std::int64_t kCosParallelNum{64 * 1024}; +} // namespace + +namespace aicpu { +namespace detail { +template +inline T ScalarCos(T x) { + return std::cos(x); +} + +template <> +inline Eigen::half ScalarCos(Eigen::half x) { + const Eigen::half val{static_cast(std::cos(static_cast(x)))}; + return Eigen::half_impl::isnan(val) ? 
Eigen::half{0.0f} : val; +} + +inline std::uint32_t ParallelForCos(const CpuKernelContext &ctx, std::int64_t total, std::int64_t per_unit_size, + const std::function &work) { + if (total > kCosParallelNum) + return aicpu::CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, work); + else + work(0, total); + return KERNEL_STATUS_OK; +} + +template +inline std::uint32_t ComputeCosKernel(const CpuKernelContext &ctx) { + T *input0{static_cast(ctx.Input(0)->GetData())}; + T *output{static_cast(ctx.Output(0)->GetData())}; + std::int64_t total{ctx.Input(0)->NumElements()}; + std::uint32_t cores{aicpu::CpuKernelUtils::GetCPUNum(ctx)}; + std::int64_t per_unit_size{total / std::min(std::max(1L, cores - 2L), total)}; + return ParallelForCos(ctx, total, per_unit_size, [&](std::int64_t begin, std::int64_t end) { + std::transform(input0 + begin, input0 + end, output + begin, ScalarCos); + }); +} + +template +inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) { + std::uint32_t result{ComputeCosKernel(ctx)}; + if (result != KERNEL_STATUS_OK) { + KERNEL_LOG_ERROR("Cos compute failed."); + } + return result; +} + +inline std::uint32_t ExtraCheckCos(const CpuKernelContext &ctx) { + if (ctx.Input(0)->GetDataType() != ctx.Output(0)->GetDataType()) { + KERNEL_LOG_ERROR("The data type of the input [%s] need be the same as the output [%s].", + DTypeStr(ctx.Input(0)->GetDataType()).c_str(), DTypeStr(ctx.Output(0)->GetDataType()).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + if (ctx.Input(0)->GetDataSize() != ctx.Output(0)->GetDataSize()) { + KERNEL_LOG_ERROR( + "The data size of the input [%llu] need be the same as the output " + "[%llu].", + ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize()); + return KERNEL_STATUS_PARAM_INVALID; + } + return KERNEL_STATUS_OK; +} + +inline std::uint32_t CheckCos(CpuKernelContext &ctx, std::uint32_t inputs_num, std::uint32_t outputs_num) { + return NormalCheck(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : ExtraCheckCos(ctx); +} + +inline std::uint32_t ComputeCos(const CpuKernelContext &ctx) { + DataType input_type{ctx.Input(0)->GetDataType()}; + switch (input_type) { + case DT_FLOAT16: + return ComputeCos(ctx); + case DT_FLOAT: + return ComputeCos(ctx); + case DT_DOUBLE: + return ComputeCos(ctx); + case DT_COMPLEX64: + return ComputeCos>(ctx); + case DT_COMPLEX128: + return ComputeCos>(ctx); + default: + KERNEL_LOG_ERROR("Unsupported input data type [%s].", DTypeStr(input_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } +} +} // namespace detail + +std::uint32_t CosCpuKernel::Compute(CpuKernelContext &ctx) { + return detail::CheckCos(ctx, kCosInputNum, kCosOutputNum) ? KERNEL_STATUS_PARAM_INVALID : detail::ComputeCos(ctx); +} + +REGISTER_CPU_KERNEL(kCos, CosCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.h new file mode 100644 index 00000000000..2de30f00996 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cos.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Jilin University + * Copyright 2020 Huawei Technologies Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_KERNELS_NORMALIZED_COS_H_ +#define AICPU_KERNELS_NORMALIZED_COS_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class CosCpuKernel final : public CpuKernel { + virtual std::uint32_t Compute(CpuKernelContext &ctx) override final; +}; +} // namespace aicpu +#endif \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.cc new file mode 100644 index 00000000000..cf210b7e953 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.cc @@ -0,0 +1,272 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "count_nonzero.h" + +#include +#include + +#include "cpu_kernel_utils.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" + +namespace { +const uint32_t kCountNonZeroInputNum = 1; +const uint32_t kCountNonZeroOutputNum = 1; +const int64_t kParallelNum = 2 * 1024; +const char *kCountNonZero = "CountNonZero"; + +// The following code is used to handle the general case. +// Params to use in ParallelIterator construction. +std::vector cnz_dims; +std::vector cnz_transposed_shape; +int64_t cnz_stride; + +// Class def of ParallelIterator. 
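+// ParallelIterator walks the input in the transposed order given by `dims` (kept
+// axes first, reduced axes last); Get() yields the matching flat offset into the
+// original buffer. For example, with input_shape = (2, 3) and dims = [1, 0],
+// transposed_shape is (3, 2), and Set(0) followed by repeated Next() visits the
+// flat offsets 0, 3, 1, 4, 2, 5.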
+class ParallelIterator { + public: + ParallelIterator(std::vector transposed_shape, std::vector dims, + const std::vector &input_shape); + ~ParallelIterator() = default; + void Next(); + void Set(int64_t pos); + inline int64_t Get() const { return _pos; }; + + private: + int64_t _dimension{0}; + std::vector _coord; + std::vector _shape; + std::vector _strides; + std::vector _back_strides; + std::vector _dims; + int64_t _pos{0}; +}; + +ParallelIterator::ParallelIterator(std::vector transposed_shape, std::vector dims, + const std::vector &input_shape) + : _dimension(transposed_shape.size()), + _coord(transposed_shape.size(), 0), + _shape(transposed_shape), + _strides(transposed_shape.size(), 1), + _back_strides(transposed_shape.size(), 1), + _dims(dims), + _pos(0) { + std::vector strides(_dimension, 1); + for (int64_t i = _dimension - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * input_shape[i + 1]; + } + for (int64_t i = _dimension - 1; i >= 0; --i) { + _strides[i] = strides[_dims[i]]; + _back_strides[i] = (_shape[i] - 1) * _strides[i]; + } +} +void ParallelIterator::Set(int64_t pos) { + for (int64_t i = _dimension - 1; i >= 0 && pos != 0; --i) { + _coord[i] = pos % _shape[i]; + _pos += _coord[i] * _strides[i]; + pos /= _shape[i]; + } +} +void ParallelIterator::Next() { + for (int64_t i = _dimension - 1; i >= 0; --i) { + if (_coord[i] + 1 == _shape[i]) { + _coord[i] = 0; + _pos -= _back_strides[i]; + } else { + _coord[i]++; + _pos += _strides[i]; + break; + } + } +} + +// The two structs is used for tag dispatch in IsNonZero. +template +struct is_complex_t : std::false_type {}; +template +struct is_complex_t> : std::true_type {}; + +template +int64_t IsNonZero(T val, std::true_type) { + return val.real() != 0 || val.imag() != 0 ? static_cast(1) : static_cast(0); +} +template +int64_t IsNonZero(T val, std::false_type) { + return val != static_cast(0) ? static_cast(1) : static_cast(0); +} +} // namespace + +namespace aicpu { +template +uint32_t CountNonZeroComputeImpl(CpuKernelContext &ctx) { + Tensor *x_tensor = ctx.Input(kFirstInputIndex); + Tensor *y_tensor = ctx.Output(kFirstOutputIndex); + const T *x_ptr = reinterpret_cast(x_tensor->GetData()); + int64_t *y_ptr = reinterpret_cast(y_tensor->GetData()); + int64_t data_num = y_tensor->NumElements(); + int64_t input_data_num = x_tensor->NumElements(); + std::vector input_shape = x_tensor->GetTensorShape()->GetDimSizes(); + + // For scalar_reduction, start=0, end=input_data_num, cannot be parallelized. + auto count_nonzero_scalar_shard = [&](int64_t start, int64_t end) { + y_ptr[0] = 0; + for (int64_t i = start; i < end; ++i) { + y_ptr[0] += IsNonZero(x_ptr[i], is_complex_t{}); + } + }; + + // For general case. Can be parallelized but performance is not good. 
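+  // Each output element i owns the i-th block of cnz_stride consecutive positions
+  // in that transposed order: the iterator is seeked to the block start and then
+  // advanced cnz_stride times while counting nonzeros, which reduces exactly the
+  // axes listed in dims.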
+ auto count_nonzero_shard = [&](int64_t start, int64_t end) { + ParallelIterator iter(cnz_transposed_shape, cnz_dims, input_shape); + iter.Set(start * cnz_stride); + for (int64_t i = start; i < end; ++i) { + int64_t reduce_initial = static_cast(0); + for (int64_t j = 0; j < cnz_stride; ++j) { + reduce_initial += IsNonZero(x_ptr[iter.Get()], is_complex_t{}); + iter.Next(); + } + y_ptr[i] = reduce_initial; + } + }; + + if (data_num == 1) { + count_nonzero_scalar_shard(0, input_data_num); + } else if (data_num > kParallelNum) { + CpuKernelUtils::ParallelFor(ctx, data_num, 1, count_nonzero_shard); + } else { + count_nonzero_shard(0, data_num); + } + return KERNEL_STATUS_OK; +} + +uint32_t CountNonZeroDimsCheckAndParse(CpuKernelContext &ctx) { + std::vector input_shape = ctx.Input(kFirstInputIndex)->GetTensorShape()->GetDimSizes(); + int64_t input_rank = input_shape.size(); + std::vector dims{}; + + auto dims_attr = ctx.GetAttr("dims"); + if (dims_attr != nullptr) { + dims = dims_attr->GetListInt(); + } + if (dims.size() == 0) { + for (int64_t i = 0; i < input_rank; ++i) { + dims.push_back(i); + } + } + // Check dims in [-x_rank, x_rank) + for (auto &dim : dims) { + if (dim < 0) { + dim += input_rank; + } + KERNEL_CHECK_FALSE(dim < input_rank && dim >= 0, KERNEL_STATUS_PARAM_INVALID, + "[CountNonZero] dims must be in [-x_rank, x_rank)."); + } + std::sort(dims.begin(), dims.end()); + dims.erase(std::unique(dims.begin(), dims.end()), dims.end()); + int64_t stride_ = static_cast(1); + std::vector transposed_shape(input_rank); + // axes is the transpose of indices. + // For example, if input_rank = 5, dims = [1, 3], + // then axes =[0, 2, 4] + [1, 3]. + // Initial value if axes is [?, ?, ?, ?, ?] + std::vector axes_(input_rank); + int64_t j = static_cast(0), k = static_cast(0); + // Put dim indices to keep to front of axes and calculate stride. + // After this operation, axes becomes [0, 2, 4] + [?, ?], + // and stride becomes 1 * 2 * 4 + for (int64_t i = 0; i < input_rank; i++) { + if (j == static_cast(dims.size()) || i != dims[j]) { + axes_[k] = i; + ++k; + } else { + stride_ *= input_shape[i]; + ++j; + } + } + // Put dim indices to reduce to back of axes. + // After this operation, axes becomes [0, 2, 4] + [1, 3] + for (auto &dim : dims) { + axes_[k] = dim; + ++k; + } + // Calculate transposed_shape using axes. + // For example, if input_shape = (3, 4, 5, 6, 7), axes = [0, 2, 4, 1, 3], + // then transposed_shape = (3, 5, 7) + (4, 6) + std::vector transposed_shape_(input_rank); + for (int64_t i = 0; i < input_rank; ++i) { + transposed_shape_[i] = input_shape[axes_[i]]; + } + // Assign values. 
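+  // Note that cnz_dims, cnz_transposed_shape and cnz_stride are file-scope globals
+  // shared between parsing and compute, so dims are re-parsed immediately before
+  // every compute call rather than cached per kernel instance.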
+ cnz_stride = stride_, cnz_transposed_shape = transposed_shape_, cnz_dims = axes_; + return KERNEL_STATUS_OK; +} + +uint32_t CountNonZeroCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCountNonZeroInputNum, kCountNonZeroOutputNum), + "[%s] check input and output failed.", kCountNonZero); + KERNEL_HANDLE_ERROR(CountNonZeroDimsCheckAndParse(ctx), "[%s] check & parse dims failed.", kCountNonZero); + auto y_data_type = ctx.Output(kFirstOutputIndex)->GetDataType(); + KERNEL_CHECK_FALSE(y_data_type == DT_INT64, KERNEL_STATUS_PARAM_INVALID, + "[CountNonZero] Data type of output not supported, which is [%s].", DTypeStr(y_data_type).c_str()); + auto x_data_type = ctx.Input(kFirstInputIndex)->GetDataType(); + switch (x_data_type) { + case DT_INT8: + return CountNonZeroComputeImpl(ctx); + break; + case DT_INT16: + return CountNonZeroComputeImpl(ctx); + break; + case DT_INT32: + return CountNonZeroComputeImpl(ctx); + break; + case DT_INT64: + return CountNonZeroComputeImpl(ctx); + break; + case DT_UINT8: + return CountNonZeroComputeImpl(ctx); + break; + case DT_UINT16: + return CountNonZeroComputeImpl(ctx); + break; + case DT_UINT32: + return CountNonZeroComputeImpl(ctx); + break; + case DT_UINT64: + return CountNonZeroComputeImpl(ctx); + break; + case DT_FLOAT16: + return CountNonZeroComputeImpl(ctx); + break; + case DT_FLOAT: + return CountNonZeroComputeImpl(ctx); + break; + case DT_DOUBLE: + return CountNonZeroComputeImpl(ctx); + break; + case DT_COMPLEX64: + return CountNonZeroComputeImpl>(ctx); + break; + case DT_COMPLEX128: + return CountNonZeroComputeImpl>(ctx); + break; + default: + KERNEL_LOG_ERROR("[CountNonZero] kernel data type [%s] not support.", DTypeStr(x_data_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + return KERNEL_STATUS_OK; +} + +REGISTER_CPU_KERNEL(kCountNonZero, CountNonZeroCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.h new file mode 100644 index 00000000000..8c9e15448c5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/count_nonzero.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H +#define AICPU_KERNELS_NORMALIZED_COUNT_NONZERO_H + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class CountNonZeroCpuKernel : public CpuKernel { + public: + ~CountNonZeroCpuKernel() override = default; + + protected: + uint32_t Compute(CpuKernelContext &ctx) override; +}; +} // namespace aicpu +#endif \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.cc new file mode 100644 index 00000000000..161d8b45c0c --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.cc @@ -0,0 +1,124 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "csr_sparse_matrix_to_dense.h" +#include +#include +#include +#include +#include "cpu_kernel_utils.h" +#include "cpu_types.h" +#include "kernel_log.h" +#include "status.h" +#include "utils/allocator_utils.h" +#include "utils/kernel_util.h" + +using namespace std; + +namespace { +const uint32_t kInputNum = 5; +const uint32_t kOutputNum = 1; +const char *CSRSparseMatrixToDense = "CSRSparseMatrixToDense"; + +#define SWITCH_CASE(_IDX_T, _VALUE_T, VALUE_T, FLAG, CTX) \ + case _VALUE_T: \ + switch (_IDX_T) { \ + case DT_INT32: \ + (FLAG) = DoCompute(CTX); \ + break; \ + case DT_INT64: \ + (FLAG) = DoCompute(CTX); \ + break; \ + default: \ + KERNEL_LOG_ERROR("CSRSparseMatrixToDense index type [%s] not support.", DTypeStr(_IDX_T).c_str()); \ + return KERNEL_STATUS_PARAM_INVALID; \ + } \ + break; + +} // namespace + +namespace aicpu { +uint32_t CSRSparseMatrixToDenseCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToDense normal check failed."); + DataType indice_type = ctx.Input(0)->GetDataType(); + DataType value_type = ctx.Input(4)->GetDataType(); + uint32_t status = 0; + switch (value_type) { + SWITCH_CASE(indice_type, DT_FLOAT, float_t, status, ctx) + SWITCH_CASE(indice_type, DT_DOUBLE, double_t, status, ctx) + SWITCH_CASE(indice_type, DT_COMPLEX64, complex, status, ctx) + SWITCH_CASE(indice_type, DT_COMPLEX128, complex, status, ctx) + default: + KERNEL_LOG_ERROR("CSRSparseMatrixToDense values type [%s] not support.", DTypeStr(value_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + KERNEL_HANDLE_ERROR(status, "CSRSparseMatrixToDense compute failed!"); + return KERNEL_STATUS_OK; +} + +template +uint32_t CSRSparseMatrixToDenseCpuKernel::DoCompute(CpuKernelContext &ctx) { + indiceT batch_size = ctx.Input(1)->NumElements() - 1; + auto rank = ctx.Input(0)->NumElements(); + int shift = (rank == 2) ? 
0 : 1; + indiceT num_rows = *(static_cast(ctx.Input(0)->GetData()) + shift); + indiceT num_cols = *(static_cast(ctx.Input(0)->GetData()) + shift + 1); + indiceT *batch_ptrs = static_cast(ctx.Input(1)->GetData()); + indiceT *row_ptrs = static_cast(ctx.Input(2)->GetData()); + indiceT *col_ind = static_cast(ctx.Input(3)->GetData()); + valueT *values = static_cast(ctx.Input(4)->GetData()); + auto output = ctx.Output(0); + auto output_shape = output->GetTensorShape(); + if (rank == 2) { + output_shape->SetDimSizes({num_rows, num_cols}); + } else { + output_shape->SetDimSizes({batch_size, num_rows, num_cols}); + } + output->SetTensorShape(output_shape.get()); + valueT *y_data = static_cast(ctx.Output(0)->GetData()); + // use multi-thread + uint32_t min_core = 1; + uint32_t max_core = std::max(min_core, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + max_core = std::min(max_core, (uint32_t)batch_size); + if (max_core == 0) { + KERNEL_LOG_ERROR("Max core num cannot be zero."); + return KERNEL_STATUS_PARAM_INVALID; + } + auto shard = [&](int64_t start, int64_t end) { + for (int64_t batch_idx = start; batch_idx < end; batch_idx++) { + const int64_t dense_offset = batch_idx * num_rows * num_cols; + for (int64_t i = 0; i < num_rows * num_cols; ++i) { + y_data[dense_offset + i] = 0; + } + const int64_t csr_batch_offset = batch_ptrs[batch_idx]; + for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) { + const int64_t row_offset = batch_idx * (num_rows + 1) + row_idx; + const int64_t col_begin = row_ptrs[row_offset]; + const int64_t col_end = row_ptrs[row_offset + 1]; + for (int64_t i = col_begin; i < col_end; ++i) { + const int64_t col_idx = col_ind[csr_batch_offset + i]; + y_data[dense_offset + (row_idx * num_cols) + col_idx] = values[csr_batch_offset + i]; + } + } + } + }; + KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core, shard), + "CSRSparseMatrixToDense Compute failed."); + return KERNEL_STATUS_OK; +} + +// register the opetaor +REGISTER_CPU_KERNEL(CSRSparseMatrixToDense, CSRSparseMatrixToDenseCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h new file mode 100644 index 00000000000..178865f58f3 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h @@ -0,0 +1,35 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h
new file mode 100644
index 00000000000..178865f58f3
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_dense.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
+#define AICPU_KERNELS_NORMALIZED_CSR_SPARSE_MATRIX_TO_DENSE_H_
+
+#include "Eigen/Core"
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+
+class CSRSparseMatrixToDenseCpuKernel : public CpuKernel {
+ public:
+  ~CSRSparseMatrixToDenseCpuKernel() = default;
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename indiceT, typename valueT>
+  uint32_t DoCompute(CpuKernelContext &ctx);
+};
+
+}  // namespace aicpu
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc
new file mode 100644
index 00000000000..9d8f6ddcf62
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc
@@ -0,0 +1,229 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "csr_sparse_matrix_to_sparse_tensor.h"
+
+#include <algorithm>
+#include <complex>
+
+#include "cpu_kernel_utils.h"
+#include "utils/kernel_util.h"
+
+namespace {
+const uint32_t kInputNum = 5;
+const uint32_t kOutputNum = 3;
+const char *CSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
+// when input data size is more than kParallelDataNum, use Parallel func
+const int64_t kParallelDataNum = 4;
+const int64_t kParallelDataNumMid = 32;
+const int DIM2 = 2;
+const int DIM3 = 3;
+}  // namespace
+
+namespace aicpu {
+uint32_t CSRSparseMatrixToSparseTensorCpuKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CSRSparseMatrixToSparseTensor normal check failed.");
+  Tensor *x_dense_shape = ctx.Input(0);
+  Tensor *x_batch_pointers = ctx.Input(1);
+  Tensor *x_row_pointers = ctx.Input(2);
+  Tensor *x_col_indices = ctx.Input(3);
+  Tensor *x_values = ctx.Input(4);
+
+  const int rank = x_dense_shape->NumElements();
+  if (rank != DIM2 && rank != DIM3) {
+    KERNEL_LOG_ERROR("CSR SparseMatrix must have rank 2 or 3.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  auto x_row_pointers_shape = x_row_pointers->GetTensorShape();
+  auto x_col_indices_shape = x_col_indices->GetTensorShape();
+  auto x_values_shape = x_values->GetTensorShape();
+  if (x_col_indices_shape->NumElements() != x_values_shape->NumElements()) {
+    KERNEL_LOG_ERROR("Tensor x_col_indices and x_values must have the same number of elements.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  auto x_dense_shape_data_type = x_dense_shape->GetDataType();
+  auto x_batch_pointers_data_type = x_batch_pointers->GetDataType();
+  auto x_row_pointers_data_type = x_row_pointers->GetDataType();
+  auto x_col_indices_data_type = x_col_indices->GetDataType();
+  if (x_col_indices_data_type != DT_INT32 && x_col_indices_data_type != DT_INT64) {
+    KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type [%s] not support.",
+                     DTypeStr(x_col_indices_data_type).c_str());
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  if (x_dense_shape_data_type != x_col_indices_data_type || x_batch_pointers_data_type != x_col_indices_data_type ||
+      x_row_pointers_data_type != x_col_indices_data_type) {
+    KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel data type mismatch.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  auto x_values_data_type = x_values->GetDataType();
+
+  uint32_t status;
+  switch (x_col_indices_data_type) {
+    case DT_INT32:
+      switch (x_values_data_type) {
+        case DT_FLOAT:
+          status = ComputeKernel<int32_t, float>(ctx);
+          break;
+        case DT_DOUBLE:
+          status = ComputeKernel<int32_t, double>(ctx);
+          break;
+        case DT_COMPLEX64:
+          status = ComputeKernel<int32_t, std::complex<float> >(ctx);
+          break;
+        case DT_COMPLEX128:
+          status = ComputeKernel<int32_t, std::complex<double> >(ctx);
+          break;
+        default:
+          KERNEL_LOG_ERROR(
+            "CSRSparseMatrixToSparseTensor kernel data type [%s] not "
+            "support.",
+            DTypeStr(x_values_data_type).c_str());
+          return KERNEL_STATUS_PARAM_INVALID;
+      }
+      break;
+    case DT_INT64:
+      switch (x_values_data_type) {
+        case DT_FLOAT:
+          status = ComputeKernel<int64_t, float>(ctx);
+          break;
+        case DT_DOUBLE:
+          status = ComputeKernel<int64_t, double>(ctx);
+          break;
+        case DT_COMPLEX64:
+          status = ComputeKernel<int64_t, std::complex<float> >(ctx);
+          break;
+        case DT_COMPLEX128:
+          status = ComputeKernel<int64_t, std::complex<double> >(ctx);
+          break;
+        default:
+          KERNEL_LOG_ERROR(
+            "CSRSparseMatrixToSparseTensor kernel data type [%s] not "
+            "support.",
+            DTypeStr(x_values_data_type).c_str());
+          return KERNEL_STATUS_PARAM_INVALID;
+      }
+      break;
+    default:
+      KERNEL_LOG_ERROR("Data type of indices is not int32 or int64.");
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  if (status != KERNEL_STATUS_OK) {
+    KERNEL_LOG_ERROR("CSRSparseMatrixToSparseTensor kernel compute failed.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+REGISTER_CPU_KERNEL(CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTensorCpuKernel);
+
+template <typename indicesT, typename dataT>
+uint32_t CSRSparseMatrixToSparseTensorCpuKernel::ComputeKernel(CpuKernelContext &ctx) {
+  auto x_dense_shape = ctx.Input(0);
+  auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
+  auto dense_shape_ptr = static_cast<indicesT *>(ctx.Output(2)->GetData());
+  auto values_ptr = static_cast<dataT *>(ctx.Output(1)->GetData());
+  auto x_values = ctx.Input(4);
+  auto x_values_ptr = static_cast<dataT *>(x_values->GetData());
+
+  // Copy the SparseTensor's dense_shape and values from the CSRSparseMatrix.
+  for (int64_t i = 0; i < x_dense_shape->GetTensorShape()->NumElements(); i++) {
+    dense_shape_ptr[i] = x_dense_shape_ptr[i];
+  }
+  for (int64_t i = 0; i < x_values->GetTensorShape()->NumElements(); i++) {
+    values_ptr[i] = x_values_ptr[i];
+  }
+
+  const uint32_t batch_size = ctx.Input(1)->NumElements() - 1;
+  if (batch_size >= kParallelDataNum) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
+
+    if (batch_size <= kParallelDataNumMid) {
+      max_core_num = std::min(max_core_num, 4U);  // up to 4 cpu cores
+    }
+
+    if (max_core_num > batch_size) {
+      max_core_num = batch_size;
+    }
+
+    auto sharder = [&](int64_t batch_begin, int64_t batch_end) {
+      SpecialCompute<indicesT>(batch_begin, batch_end, ctx);
+    };
+
+    if (max_core_num == 0) {
+      return KERNEL_STATUS_PARAM_INVALID;  // defensive: avoids dividing by zero below
+    }
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, batch_size / max_core_num, sharder),
+                        "CSRSparseMatrixToSparseTensor Compute failed.");
+  } else {
+    SpecialCompute<indicesT>(0, batch_size, ctx);
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+template <typename indicesT>
+void CSRSparseMatrixToSparseTensorCpuKernel::SpecialCompute(int64_t batch_begin, int64_t batch_end,
+                                                            CpuKernelContext &ctx) {
+  auto x_dense_shape = ctx.Input(0);
+  const int rank = x_dense_shape->NumElements();
+  auto x_dense_shape_ptr = static_cast<indicesT *>(x_dense_shape->GetData());
+  const int64_t num_rows = x_dense_shape_ptr[(rank == DIM2) ? 0 : 1];
+  auto x_batch_pointers_ptr = static_cast<indicesT *>(ctx.Input(1)->GetData());
+  auto x_row_pointers_ptr = static_cast<indicesT *>(ctx.Input(2)->GetData());
+  auto x_col_indices_ptr = static_cast<indicesT *>(ctx.Input(3)->GetData());
+
+  for (int64_t batch_idx = batch_begin; batch_idx < batch_end; ++batch_idx) {
+    const int64_t batch_offset = x_batch_pointers_ptr[batch_idx];
+
+    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+      int64_t row_offset = batch_idx * (num_rows + 1) + row_idx;
+
+      // The column indices of the current row lie in the half-open range
+      // [x_row_pointers_ptr[row_offset], x_row_pointers_ptr[row_offset + 1])
+      const int64_t col_begin = x_row_pointers_ptr[row_offset];
+      const int64_t col_end = x_row_pointers_ptr[row_offset + 1];
+      for (int64_t i = col_begin; i < col_end; ++i) {
+        const int64_t col_idx = x_col_indices_ptr[batch_offset + i];
+        const int64_t indices_offset = rank * (batch_offset + i);
+        IndicesCompute<indicesT>(ctx, indices_offset, batch_idx, row_idx, col_idx);
+      }
+    }
+  }
+}
+
+template <typename indicesT>
+void CSRSparseMatrixToSparseTensorCpuKernel::IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset,
+                                                            const int64_t batch_idx, const int64_t row_idx,
+                                                            const int64_t col_idx) {
+  const int rank = ctx.Input(0)->NumElements();
+  auto indices_ptr = static_cast<indicesT *>(ctx.Output(0)->GetData());
+  if (rank == DIM2) {
+    indices_ptr[indices_offset] = row_idx;
+    indices_ptr[++indices_offset] = col_idx;
+  } else {  // rank == 3
+    indices_ptr[indices_offset] = batch_idx;
+    indices_ptr[++indices_offset] = row_idx;
+    indices_ptr[++indices_offset] = col_idx;
+  }
+}
+}  // namespace aicpu
\ No newline at end of file
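The kernel above turns implicit CSR row pointers into the explicit per-nonzero (row, col) pairs a COO SparseTensor stores. A compact standalone sketch of that expansion for one batch (illustrative names, independent of the kernel framework):

#include <cstdint>
#include <utility>
#include <vector>

// Expand CSR row pointers into explicit COO (row, col) pairs, in the same
// order the kernel above writes its indices output.
std::vector<std::pair<int64_t, int64_t>> CsrToCooIndices(const std::vector<int64_t> &row_ptrs,
                                                         const std::vector<int64_t> &col_ind) {
  std::vector<std::pair<int64_t, int64_t>> indices;
  const int64_t num_rows = static_cast<int64_t>(row_ptrs.size()) - 1;
  for (int64_t r = 0; r < num_rows; ++r) {
    // Nonzeros of row r occupy the half-open range [row_ptrs[r], row_ptrs[r + 1]).
    for (int64_t i = row_ptrs[r]; i < row_ptrs[r + 1]; ++i) {
      indices.emplace_back(r, col_ind[i]);
    }
  }
  return indices;
}

With row_ptrs = {0, 1, 1, 3} and col_ind = {2, 0, 3} this yields (0, 2), (2, 0), (2, 3); for rank-3 inputs the batched kernel simply prepends batch_idx to each pair.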
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.h
new file mode 100644
index 00000000000..d17e07f185b
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
+#define AICPU_KERNELS_NORMALIZED_CSRSPARSEMATRIXTOSPARSETENSOR_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+class CSRSparseMatrixToSparseTensorCpuKernel : public CpuKernel {
+ public:
+  ~CSRSparseMatrixToSparseTensorCpuKernel() = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename indicesT, typename dataT>
+  uint32_t ComputeKernel(CpuKernelContext &ctx);
+
+  template <typename indicesT>
+  void SpecialCompute(int64_t batch_begin, int64_t batch_end, CpuKernelContext &ctx);
+
+  template <typename indicesT>
+  void IndicesCompute(CpuKernelContext &ctx, int64_t indices_offset, const int64_t batch_idx, const int64_t row_idx,
+                      const int64_t col_idx);
+};
+}  // namespace aicpu
+#endif
\ No newline at end of file
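These kernels gate parallelism the same way: batches below kParallelDataNum run serially, otherwise the work is split with CpuKernelUtils::ParallelFor over roughly (GetCPUNum(ctx) - 2) shards of contiguous batches. ParallelFor itself is the framework API; the following framework-free std::thread sketch (all names illustrative) only mirrors the block-sharding idea:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Run shard(start, end) over [0, total) in contiguous blocks, one per thread,
// mirroring the ParallelFor usage in the kernels above.
void BlockParallelFor(int64_t total, uint32_t num_threads, const std::function<void(int64_t, int64_t)> &shard) {
  if (total <= 0) return;
  num_threads = std::max(1U, std::min<uint32_t>(num_threads, static_cast<uint32_t>(total)));
  const int64_t block = (total + num_threads - 1) / num_threads;  // ceiling division
  std::vector<std::thread> workers;
  for (int64_t start = 0; start < total; start += block) {
    workers.emplace_back(shard, start, std::min(start + block, total));
  }
  for (auto &t : workers) t.join();
}

Clamping the thread count to the batch count (as the kernels do with max_core_num) keeps every shard non-empty, so the per-shard loop bounds never invert.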
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc
new file mode 100644
index 00000000000..da8d969b5d7
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc
@@ -0,0 +1,202 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "cumprod.h"
+
+#include "cpu_kernel_utils.h"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+
+namespace {
+const uint32_t kCumprodInputNum = 2;
+const uint32_t kCumprodOutputNum = 1;
+const int64_t parallel_data_size = 512 * 1024;
+const char *kCumprod = "Cumprod";
+#define CUMPROD_COMPUTE_CASE(DTYPE, TYPE, CTX)             \
+  case (DTYPE): {                                          \
+    uint32_t result = CumprodCompute<TYPE>(CTX);           \
+    if (result != KERNEL_STATUS_OK) {                      \
+      KERNEL_LOG_ERROR("Cumprod kernel compute failed."); \
+      return result;                                       \
+    }                                                      \
+    break;                                                 \
+  }
+}  // namespace
+
+namespace aicpu {
+uint32_t CumprodCpuKernel::Compute(CpuKernelContext &ctx) {
+  // check params
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kCumprodInputNum, kCumprodOutputNum), "[%s] check input and output failed.",
+                      kCumprod);
+  // parse params
+  KERNEL_HANDLE_ERROR(CumprodCheck(ctx), "[%s] check params failed.", kCumprod);
+  auto data_type = ctx.Input(0)->GetDataType();
+  switch (data_type) {
+    CUMPROD_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
+    CUMPROD_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    CUMPROD_COMPUTE_CASE(DT_DOUBLE, double, ctx)
+    CUMPROD_COMPUTE_CASE(DT_INT8, int8_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_INT16, int16_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_INT32, int32_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_INT64, int64_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
+    CUMPROD_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
+    CUMPROD_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
+    default:
+      KERNEL_LOG_ERROR("Cumprod kernel data type [%s] not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+  return KERNEL_STATUS_OK;
+}
+uint32_t CumprodCpuKernel::CumprodCheck(CpuKernelContext &ctx) {
+  KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input failed.");
+  KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.")
+  KERNEL_CHECK_NULLPTR(ctx.GetAttr("exclusive"), KERNEL_STATUS_PARAM_INVALID, "Get exclusive failed.");
+  KERNEL_CHECK_NULLPTR(ctx.GetAttr("reverse"), KERNEL_STATUS_PARAM_INVALID, "Get reverse failed.");
+  KERNEL_CHECK_FALSE((ctx.Input(1)->GetDataType() == DT_INT32 || ctx.Input(1)->GetDataType() == DT_INT64),
+                     KERNEL_STATUS_PARAM_INVALID,
+                     "Data type of axis is not support, axis data type is [%u], only support int32 or int64.",
+                     ctx.Input(1)->GetDataType());
+  KERNEL_CHECK_FALSE(ctx.Input(1)->NumElements() == 1, KERNEL_STATUS_PARAM_INVALID,
+                     "axis must contain exactly 1 element.")
+  auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
+  int32_t axis = *axis_data;
+  KERNEL_CHECK_FALSE((axis < ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
+                     "axis is larger than input dims - 1");
+  KERNEL_CHECK_FALSE((axis >= -ctx.Input(0)->GetTensorShape()->GetDims()), KERNEL_STATUS_PARAM_INVALID,
+                     "axis is lower than -input dims");
+  std::vector<int64_t> shape_input = ctx.Input(0)->GetTensorShape()->GetDimSizes();
+  std::vector<int64_t> shape_output = ctx.Output(0)->GetTensorShape()->GetDimSizes();
+  KERNEL_CHECK_FALSE((shape_input.size() != 0), KERNEL_STATUS_PARAM_INVALID,
+                     "Input must be at least rank 1, got [%zu].", shape_input.size())
+  KERNEL_CHECK_FALSE((shape_input.size() == shape_output.size()), KERNEL_STATUS_PARAM_INVALID,
+                     "The input shape size should be the same as the output shape size.")
+  return KERNEL_STATUS_OK;
+}
+template <typename T>
+uint32_t CumprodCpuKernel::CumprodCompute(CpuKernelContext &ctx) {
+  auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
+  auto axis_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
+  int32_t axis = *axis_data;
+  bool exclusive = ctx.GetAttr("exclusive")->GetBool();
+  bool reverse = ctx.GetAttr("reverse")->GetBool();
+  auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+  auto shape = ctx.Input(0)->GetTensorShape();
+  const int64_t rank = shape->GetDims();
+  if (axis < 0) {
+    axis += shape->GetDims();
+  }
+  size_t inner = 1;
+  size_t outer = 1;
+  size_t depth = 1;
+  for (int32_t i = 0; i < rank; ++i) {
+    if (i < axis) {
+      inner *= shape->GetDimSize(i);
+    } else if (i > axis) {
+      outer *= shape->GetDimSize(i);
+    } else {
+      depth = shape->GetDimSize(i);
+    }
+  }
+  int64_t data_num = ctx.Input(0)->NumElements();
+  int64_t data_size = data_num * sizeof(T);
+  if (data_size <= parallel_data_size) {
+    for (size_t outer_index = 0; outer_index < outer; ++outer_index) {
+      size_t outer_index_adj;
+      if (reverse)
+        outer_index_adj = (outer - 1) - outer_index;
+      else
+        outer_index_adj = outer_index;
+      for (size_t inner_index = 0; inner_index < inner; inner_index++) {
+        auto multiplier = static_cast<T>(1);
+        size_t inner_index_adj;
+        if (reverse)
+          inner_index_adj = (inner - 1) - inner_index;
+        else
+          inner_index_adj = inner_index;
+        for (size_t depth_index = 0; depth_index < depth; depth_index++) {
+          size_t depth_index_adj;
+          if (reverse)
+            depth_index_adj = (depth - 1) - depth_index;
+          else
+            depth_index_adj = depth_index;
+          size_t index = outer_index_adj;
+          index += inner_index_adj * depth * outer;
+          index += depth_index_adj * outer;
+          if (exclusive) {
+            output_data[index] = multiplier;
+            multiplier *= input_data[index];
+          } else {
+            multiplier *= input_data[index];
+            output_data[index] = multiplier;
+          }
+        }
+      }
+    }
+  } else {
+    auto shard_cumprod = [&](size_t start, size_t end) {
+      for (size_t outer_index = start; outer_index < end; ++outer_index) {
+        size_t outer_index_adj;
+        if (reverse) {
+          outer_index_adj = (outer - 1) - outer_index;
+        } else {
+          outer_index_adj = outer_index;
+        }
+        for (size_t inner_index = 0; inner_index < inner; inner_index++) {
+          auto multiplier = static_cast<T>(1);
+          size_t inner_index_adj;
+          if (reverse) {
+            inner_index_adj = (inner - 1) - inner_index;
+          } else {
+            inner_index_adj = inner_index;
+          }
+          for (size_t depth_index = 0; depth_index < depth; depth_index++) {
+            size_t depth_index_adj;
+            if (reverse) {
+              depth_index_adj = (depth - 1) - depth_index;
+            } else {
+              depth_index_adj = depth_index;
+            }
+            size_t index = outer_index_adj;
+            index += inner_index_adj * depth * outer;
+            index += depth_index_adj * outer;
+            if (exclusive) {
+              output_data[index] = multiplier;
+              multiplier *= input_data[index];
+            } else {
+              multiplier *= input_data[index];
+              output_data[index] = multiplier;
+            }
+          }
+        }
+      }
+    };
+    uint32_t min_core_num = 1;
+    size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
+    if (max_core_num > outer) {
+      max_core_num = outer;
+    }
+    if (max_core_num == 0) {
+      return KERNEL_STATUS_PARAM_INVALID;  // defensive: avoids dividing by zero below
+    }
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, outer, outer / max_core_num, shard_cumprod),
+                        "Cumprod Compute failed.");
+  }
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kCumprod, CumprodCpuKernel);
+}  // namespace aicpu
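The kernel flattens the tensor into an (inner, depth, outer) decomposition, where depth is the scan axis; note that in this file inner accumulates the dimensions before the axis and outer those after it, so one step along the axis has stride outer. The exclusive flag emits the running product before multiplying in the current element, and reverse walks the axis back to front. A 1-D reference of those semantics (illustrative, not the kernel API):

#include <cstdint>
#include <vector>

// 1-D reference for Cumprod's exclusive/reverse semantics: exclusive shifts
// the products by one position (seeding with 1), reverse scans from the back.
std::vector<double> Cumprod1D(const std::vector<double> &x, bool exclusive, bool reverse) {
  const int64_t n = static_cast<int64_t>(x.size());
  std::vector<double> y(n);
  double acc = 1.0;
  for (int64_t step = 0; step < n; ++step) {
    const int64_t i = reverse ? (n - 1) - step : step;
    if (exclusive) {
      y[i] = acc;   // product of elements strictly before i (in scan order)
      acc *= x[i];
    } else {
      acc *= x[i];  // product including x[i]
      y[i] = acc;
    }
  }
  return y;
}

For x = {2, 3, 4}: the plain scan gives {2, 6, 24}, exclusive gives {1, 2, 6}, reverse gives {24, 12, 4}, and exclusive plus reverse gives {12, 4, 1}.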
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.h
new file mode 100644
index 00000000000..795d8d937ad
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_KERNELS_NORMALIZED_CUMPROD_H_
+#define AICPU_KERNELS_NORMALIZED_CUMPROD_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+class CumprodCpuKernel : public CpuKernel {
+ public:
+  CumprodCpuKernel() = default;
+  ~CumprodCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  uint32_t CumprodCheck(CpuKernelContext &ctx);
+
+  template <typename T>
+  uint32_t CumprodCompute(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.cc
index d341ec3f987..514f9bbe3ba 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.cc
@@ -208,4 +208,4 @@ uint32_t CumulativeLogsumexpCpuKernel::CumulativeLogsumexpCompute(CpuKernelConte
   return KERNEL_STATUS_OK;
 }
 REGISTER_CPU_KERNEL(KCumulativeLogsumexp, CumulativeLogsumexpCpuKernel);
-}  // namespace aicpu
\ No newline at end of file
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.h
index 3fa4faf7534..5143023a050 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumulativelogsumexp.h
@@ -35,4 +35,4 @@ class CumulativeLogsumexpCpuKernel : public CpuKernel {
   uint32_t CumulativeLogsumexpCompute(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
-#endif
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/shuffle_channel.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/shuffle_channel.cc
new file mode 100644
index 00000000000..e09f39a7b4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/shuffle_channel.cc
@@ -0,0 +1,166 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "shuffle_channel.h"
+
+#include "cpu_kernel_utils.h"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include <vector>
+
+namespace {
+const uint32_t kInputNum = 1;
+const uint32_t kOutputNum = 1;
+const char *kShuffleChannel = "ShuffleChannel";
+const int64_t minDimSize = 3;
+
+#define SHUFFLE_CHANNEL_COMPUTE_CASE(DTYPE, TYPE, CTX)                    \
+  case (DTYPE): {                                                         \
+    uint32_t result = ShuffleChannelCompute<TYPE>(CTX);                   \
+    if (result != KERNEL_STATUS_OK) {                                     \
+      KERNEL_LOG_ERROR("Shuffle Channel kernel compute failed.");         \
+      return result;                                                      \
+    }                                                                     \
+    break;                                                                \
+  }
+}  // namespace
+
+namespace aicpu {
+uint32_t ShuffleChannelCpuKernel::Compute(CpuKernelContext &ctx) {
+  // check params
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "ShuffleChannel check input and output number failed.");
+  KERNEL_HANDLE_ERROR(ShuffleChannelParamCheck(ctx), "ShuffleChannel check params failed.");
+  auto data_type = ctx.Input(0)->GetDataType();
+  switch (data_type) {
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
+    SHUFFLE_CHANNEL_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    default:
+      KERNEL_LOG_ERROR("ShuffleChannel kernel data type [%s] not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+  return KERNEL_STATUS_OK;
+}
+
+uint32_t ShuffleChannelCpuKernel::ShuffleChannelParamCheck(CpuKernelContext &ctx) {
+  // the non-null of input and output has been verified in NormalCheck
+  Tensor *input = ctx.Input(0);
+  auto group = 1;
+  if (ctx.GetAttr("group")) {
+    group = ctx.GetAttr("group")->GetInt();
+  }
+  KERNEL_CHECK_FALSE(input->GetTensorShape()->GetDims() >= minDimSize, KERNEL_STATUS_PARAM_INVALID,
+                     "ShuffleChannel expects input with at least 3 dims.")
+  int64_t c = input->GetTensorShape()->GetDimSize(1);
+  KERNEL_CHECK_FALSE(group > 0, KERNEL_STATUS_PARAM_INVALID, "Number of groups to divide channels in must be positive.")
+  KERNEL_CHECK_FALSE((c % group) == 0, KERNEL_STATUS_PARAM_INVALID, "Number of channels must be divisible by groups.")
+  KERNEL_LOG_DEBUG(
+    "ShuffleChannelCpuKernel[%s], input: size[%llu];"
+    "output: size[%llu].",
+    ctx.GetOpType().c_str(), input->GetDataSize(), ctx.Output(0)->GetDataSize());
+
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T>
+uint32_t ShuffleChannelCpuKernel::ShuffleChannelCompute(CpuKernelContext &ctx) {
+  Tensor *input = ctx.Input(0);
+  Tensor *output = ctx.Output(0);
+  auto group = 1;
+  if (ctx.GetAttr("group")) {
+    group = ctx.GetAttr("group")->GetInt();
+  }
+
+  auto shape = input->GetTensorShape();
+  int64_t b = shape->GetDimSize(0);
+  int64_t c = shape->GetDimSize(1);
+  int64_t dims = shape->GetDims();
+
+  if (group == 0) {
+    KERNEL_LOG_ERROR("group can not be zero.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  int64_t oc = c / group;
+
+  auto out = reinterpret_cast<T *>(output->GetData());
+  auto in = reinterpret_cast<T *>(input->GetData());
+  int64_t loc_out = 0;
+  int64_t loc_in = 0;
+  int64_t area = 1;
+  for (int64_t i = 2; i < dims; i++) {
+    area = area * shape->GetDimSize(i);
+  }
+  /*
+   * view the shape as (n, group, c / group, h * w) and transpose dim 1 and dim 2
+   */
+  std::vector<int64_t> temp_shape;
+  temp_shape.push_back(b);
+  temp_shape.push_back(group);
+  temp_shape.push_back(oc);
+  temp_shape.push_back(area);
+  std::vector<int64_t> temp_loc = {0, 0, 0, 0};
+  bool can_plus = false;
+  int64_t dim = 0;
+  while (true) {
+    for (dim = 0; dim <= minDimSize; dim++) {  // 4-D view: dims 0..3
+      if (dim == minDimSize) {
+        loc_in = loc_in + temp_loc[dim] * 1;
+        loc_out = loc_out + temp_loc[dim] * 1;
+      } else if (dim == (minDimSize - 1)) {
+        loc_in = loc_in + temp_loc[dim] * area;
+        loc_out = loc_out + temp_loc[1] * area;
+      } else if (dim == 1) {
+        loc_in = loc_in + temp_loc[dim] * oc * area;
+        loc_out = loc_out + temp_loc[minDimSize - 1] * group * area;
+      } else if (dim == 0) {
+        loc_in = loc_in + temp_loc[dim] * group * oc * area;
+        loc_out = loc_out + temp_loc[dim] * group * oc * area;
+      }
+    }
+    *(out + loc_out) = *(in + loc_in);
+    loc_in = 0;
+    loc_out = 0;
+    can_plus = false;
+    for (dim = 0; dim <= minDimSize; dim++) {
+      if (temp_loc[dim] < (temp_shape[dim] - 1)) {
+        can_plus = true;
+        break;
+      }
+    }
+    if (!can_plus) {
+      break;
+    }
+    for (dim = minDimSize; dim >= 0; dim--) {
+      if (temp_loc[dim] == (temp_shape[dim] - 1)) {
+        temp_loc[dim] = 0;
+      } else {
+        temp_loc[dim] = temp_loc[dim] + 1;
+        break;
+      }
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+
+REGISTER_CPU_KERNEL(kShuffleChannel, ShuffleChannelCpuKernel);
+}  // namespace aicpu
+ */
+
+#ifndef AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
+#define AICPU_KERNELS_NORMALIZED_SHUFFLE_CHANNEL_H_
+
+#include "cpu_ops_kernel.h"
+#include "utils/bcast.h"
+
+namespace aicpu {
+class ShuffleChannelCpuKernel : public CpuKernel {
+ public:
+  ShuffleChannelCpuKernel() = default;
+  ~ShuffleChannelCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  uint32_t ShuffleChannelParamCheck(CpuKernelContext &ctx);
+
+  template <typename T>
+  uint32_t ShuffleChannelCompute(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
index 7209a411d5c..bb5af6efa76 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
@@ -52,9 +52,24 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
                                              kFSEDecodeOpName};
   static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2DV1OpName,
                                                                mindspore::kAdaptiveAvgPool2DGradV1OpName,
+                                                               mindspore::kBucketizeOpName,
                                                                mindspore::kCacheSwapTableOpName,
+                                                               mindspore::kCauchyOpName,
+                                                               mindspore::kChannelShuffleOpName,
+                                                               mindspore::kCholeskyGradOpName,
+                                                               mindspore::kCholeskyInverseOpName,
+                                                               mindspore::kCholeskySolveOpName,
                                                                mindspore::kCol2imOpName,
+                                                               mindspore::kCombinedNonMaxSuppressionOpName,
+                                                               mindspore::kComplexOpName,
+                                                               mindspore::kComplexAbsOpName,
+                                                               mindspore::kConcatOpName,
+                                                               mindspore::kCosOpName,
+                                                               mindspore::kCountNonZeroOpName,
                                                                mindspore::kCumulativeLogsumexpOpName,
+                                                               mindspore::kCumprodOpName,
+                                                               mindspore::kCSRSparseMatrixToDenseOpName,
+                                                               mindspore::kCSRSparseMatrixToSparseTensorOpName,
                                                                mindspore::kDataFormatVecPermuteOpName,
                                                                mindspore::kFillOpName,
                                                                mindspore::kLogMatrixDeterminantOpName,
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 819eeb785d4..23d6d6ce016 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -184,6 +184,20 @@ from .nan_to_num import _nan_to_num_aicpu
 from .qr import _qr_aicpu
 from .col2im import _col2im_aicpu
 from .matrix_solve_ls import _matrix_solve_ls_aicpu
+from .cauchy import _cauchy_aicpu
+from .bucketize import _bucketize_aicpu
+from .channel_shuffle import _channel_shuffle_aicpu
+from .choleskygrad import _choleskygrad_aicpu
+from .cholesky_inverse import _cholesky_inverse_aicpu
+from .cholesky_solve import _cholesky_solve_aicpu
+from .combined_non_max_suppression import _combined_non_max_suppression_aicpu
+from .complex import _complex_aicpu
+from .complex_abs import _complex_abs_aicpu
+from .concat import _concat_aicpu
+from .cos import _cos_aicpu
+from .count_nonzero import _count_nonzero_aicpu
+from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
+from .cumprod import _cumprod_aicpu
 from .exp import _exp_aicpu
 from .matrix_triangular_solve import _matrix_triangular_solve_aicpu
 from .maximum_grad_grad import _maximum_grad_grad_aicpu