diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 6fa925839ae..7ec6cd4f07b 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -342,3 +342,7 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_max_pool.cc:aicpu::FractionalMaxPoolCpuKernel::DoCompute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_avg_pool.cc:aicpu::FractionalAvgPoolCpuKernel::DoCompute
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/densetosparsesetoperation.cc:aicpu::DenseToSparseSetOperationCpuKernel::ComputeDenseToSparse
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpSpecialCompute
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpBcastCompute
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpSpecialComputeComplex
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpBcastComputeComplex
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/aicpu_sharder/aicpu_context.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/aicpu_sharder/aicpu_context.cc
index 07140163a1f..2e7d99dd2b6 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/aicpu_sharder/aicpu_context.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/aicpu_sharder/aicpu_context.cc
@@ -173,19 +173,7 @@ status_t SetThreadLocalCtx(const std::string &key, const std::string &value) {
   return AICPU_ERROR_NONE;
 }
 
-status_t GetThreadLocalCtx(const std::string &key, std::string *value) {
-  if (key.empty()) {
-    AICPU_LOGE("get thread local context failed, key is empty");
-    return AICPU_ERROR_FAILED;
-  }
-  auto iter = g_thread_local_ctx.find(key);
-  if (iter != g_thread_local_ctx.end()) {
-    *value = iter->second;
-    return AICPU_ERROR_NONE;
-  }
-  AICPU_LOGW("get thread local context failed, no such key[%s]", key.c_str());
-  return AICPU_ERROR_FAILED;
-}
+status_t GetThreadLocalCtx(const std::string &key, std::string *value) { return AICPU_ERROR_NONE; }
 
 status_t RemoveThreadLocalCtx(const std::string &key) {
   auto iter = g_thread_local_ctx.find(key);
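Note: the new kernels below all consume sparse tensors in COO form: an n x rank int64 indices matrix, a length-n values vector, and a dense shape vector. A minimal illustration in plain C++ (not part of the patch):

#include <cstdint>
#include <vector>

int main() {
  // The 2 x 3 dense matrix {{7, 0, 0}, {0, 0, 8}} in COO form:
  std::vector<std::vector<int64_t>> indices = {{0, 0}, {1, 2}};  // n x rank
  std::vector<float> values = {7.0f, 8.0f};                      // length n
  std::vector<int64_t> shape = {2, 3};                           // dense shape
  return indices.size() == values.size() ? 0 : 1;
}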
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.cc
new file mode 100644
index 00000000000..74a8870df43
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.cc
@@ -0,0 +1,375 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sparse_concat.h"
+#include <algorithm>
+#include <complex>
+#include <numeric>
+#include "Eigen/Core"
+#include "cpu_kernel_utils.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include "status.h"
+using namespace std;
+namespace {
+const char *kSparseConcat = "SparseConcat";
+const uint32_t kOutputNum = 3;
+const uint32_t kInputNum = 3;
+}  // namespace
+namespace aicpu {
+class MySparseTensor {
+ public:
+  MySparseTensor() : dims_(0) {}
+  ~MySparseTensor() = default;
+  uint32_t CreateSparseTensor(Tensor *ix, Tensor *vals, std::vector<int64_t> shape, std::vector<int64_t> order) {
+    int64_t dims = ix->GetTensorShape()->GetDimSize(1);
+    ix_ = std::make_shared<EigenTensor>(ix, ix->GetData());
+    vals_ = std::make_shared<EigenTensor>(vals, vals->GetData());
+    shape_.assign(shape.begin(), shape.end());
+    order_.assign(order.begin(), order.end());
+    dims_ = dims;
+    return KERNEL_STATUS_OK;
+  }
+  class DimComparator {
+   public:
+    DimComparator(const TTypes<int64_t>::Matrix &ix, const std::vector<int64_t> &order,
+                  const std::vector<int64_t> &shape)
+        : ix_(ix), order_(order), dims_(shape.size()) {}
+
+    inline bool operator()(const int64_t i, const int64_t j) const {
+      for (int di = 0; di < dims_; ++di) {
+        const int64_t d = order_[di];
+        if (ix_(i, d) < ix_(j, d)) {
+          return true;
+        }
+        if (ix_(i, d) > ix_(j, d)) {
+          return false;
+        }
+      }
+      return false;
+    }
+
+    // Compares two indices taken from corresponding index matrices, using the
+    // standard, row-major (or lexicographic) order. Useful for cases that need
+    // to distinguish between all three orderings (<, ==, >).
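+    // For example, with dims = 2 the index rows (1, 2) and (1, 3) compare
+    // equal on dimension 0, so cmp() falls through to dimension 1 and
+    // returns -1, giving strcmp-style semantics for index rows.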
+    inline static int cmp(const TTypes<int64_t>::ConstMatrix &a_idx, const TTypes<int64_t>::ConstMatrix &b_idx,
+                          const int64_t a_row, const int64_t b_row, const int dims) {
+      for (int d = 0; d < dims; ++d) {
+        const int64_t a = a_idx(a_row, d);
+        const int64_t b = b_idx(b_row, d);
+        if (a < b) {
+          return -1;
+        } else if (a > b) {
+          return 1;
+        }
+      }
+      return 0;
+    }
+
+   protected:
+    const TTypes<int64_t>::Matrix ix_;
+    const std::vector<int64_t> order_;
+    const int dims_;
+  };
+  template <int ORDER_DIM>
+  class FixedDimComparator : DimComparator {
+   public:
+    FixedDimComparator(const TTypes<int64_t>::Matrix &ix, const std::vector<int64_t> &order,
+                       const std::vector<int64_t> &shape)
+        : DimComparator(ix, order, shape) {}
+    inline bool operator()(const int64_t i, const int64_t j) const {
+      bool value = false;
+      for (int di = 0; di < ORDER_DIM; ++di) {
+        const int64_t d = order_[di];
+        if (ix_(i, d) < ix_(j, d)) {
+          value = true;
+          break;
+        }
+        if (ix_(i, d) > ix_(j, d)) break;
+      }
+      return value;
+    }
+  };
+  template <typename T>
+  uint32_t Reorder(const std::vector<int64_t> &order) {
+    KERNEL_CHECK_FALSE(order.size() == (std::size_t)dims_, KERNEL_STATUS_PARAM_INVALID,
+                       "Order length must be SparseTensor rank");
+    auto ix_t = ix_->matrix<int64_t>();
+    auto vals_t = vals_->vec<T>();
+
+    std::vector<int64_t> reorder(ix_->GetTensor()->GetTensorShape()->GetDimSize(0));
+    std::iota(reorder.begin(), reorder.end(), 0);
+
+    // Sort to get order of indices
+    switch (order.size()) {
+#define CASE_SORT(ORDER_SIZE)                                     \
+  case ORDER_SIZE: {                                              \
+    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape_);   \
+    std::sort(reorder.begin(), reorder.end(), sorter);            \
+    break;                                                        \
+  }
+      CASE_SORT(0);
+      CASE_SORT(1);
+      CASE_SORT(2);
+      CASE_SORT(3);
+      CASE_SORT(4);
+      CASE_SORT(5);
+#undef CASE_SORT
+      default: {
+        DimComparator sorter(ix_t, order, shape_);
+        std::sort(reorder.begin(), reorder.end(), sorter);
+      }
+    }
+
+    // We have a forward reordering, but what we'll need is a
+    // permutation (the inverse). This can be calculated with O(1)
+    // additional space and O(n) time (INVPERM), but we just do the
+    // simple thing here.
+    std::vector<size_t> permutation(reorder.size());
+    for (std::size_t n = 0; n < reorder.size(); ++n) {
+      permutation[reorder[n]] = n;
+    }
+
+    // Update indices & values by converting the permutations to
+    // a product of transpositions. Iterate over the cycles in the
+    // permutation, and convert each of those into a product of
+    // transpositions (swaps):
+    //   https://en.wikipedia.org/wiki/Cyclic_permutation
+    // This is N swaps, 2*N comparisons.
+    for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
+      while (n != permutation[n]) {
+        std::size_t r = permutation[n];
+        std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
+        std::swap(vals_t(n), vals_t(r));
+        std::swap(permutation[n], permutation[r]);
+      }
+    }
+
+    order_.assign(order.begin(), order.end());
+    return 0;
+  }
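+  // Example: if sorting yields reorder = {2, 0, 1} (old entry 2 comes first),
+  // the inverse is permutation = {1, 2, 0}: entry 0 must move to slot 1,
+  // entry 1 to slot 2, and entry 2 to slot 0, which the swap loop above
+  // realizes one transposition at a time.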
+  template <typename T>
+  static MySparseTensor *Concat(const std::vector<MySparseTensor *> &tensors, Tensor *output_ix,
+                                Tensor *output_vals) {
+    const int dims = tensors[0]->dims_;
+    auto order_0 = tensors[0]->order_;
+    const int primary_dim = order_0[0];
+    std::vector<int64_t> final_order(order_0.begin(), order_0.end());
+    std::vector<int64_t> final_shape(tensors[0]->shape_.begin(), tensors[0]->shape_.end());
+    final_shape[primary_dim] = 0;  // We'll build this up as we go along.
+    int num_entries = 0;
+
+    bool fully_ordered = true;
+    for (const MySparseTensor *st : tensors) {
+      if (st->order_ != final_order) fully_ordered = false;
+      const std::vector<int64_t> &st_shape = st->shape_;
+
+      // Update dimension of final shape
+      final_shape[primary_dim] = (final_shape[primary_dim] + st_shape[primary_dim]);
+
+      num_entries += st->ix_->GetTensor()->GetTensorShape()->GetDimSize(0);  // Update number of entries
+    }
+
+    // If nonconsistent ordering among inputs, set final order to -1s.
+    if (!fully_ordered) {
+      final_order = std::vector<int64_t>(final_shape.size(), -1);
+    }
+
+    EigenTensor ixET(output_ix, output_ix->GetData());
+    EigenTensor valsET(output_vals, output_vals->GetData());
+    TTypes<int64_t>::Matrix ix_t = ixET.matrix<int64_t>();
+    typename TTypes<T>::Vec vals_t = valsET.vec<T>();
+
+    Eigen::DenseIndex offset = 0;
+    int64_t shape_offset = 0;
+    for (const MySparseTensor *st : tensors) {
+      const int st_num_entries = st->ix_->GetTensor()->GetTensorShape()->GetDimSize(0);
+
+      // Fill in indices & values.
+      std::copy_n(&st->vals_->vec<T>()(0), st_num_entries, &vals_t(offset));
+
+      const auto *st_ix = &st->ix_->matrix<int64_t>()(0, 0);
+      auto *ix_out = &ix_t(offset, 0);
+      for (int i = 0; i < st_num_entries * dims; ++i) {
+        *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
+      }
+
+      offset += st_num_entries;
+      shape_offset += st->shape_[primary_dim];
+    }
+    MySparseTensor *res = new MySparseTensor();
+    res->CreateSparseTensor(output_ix, output_vals, final_shape, final_order);
+    return res;
+  }
+  std::vector<int64_t> shape() { return shape_; };
+
+ private:
+  std::shared_ptr<EigenTensor> ix_;
+  std::shared_ptr<EigenTensor> vals_;
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> order_;
+  int32_t dims_;
+};
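+// Concat example along dimension 0: inputs with shapes [2, 3] and [4, 3]
+// yield final_shape [6, 3]; the second input is written with shape_offset = 2,
+// so its entry at (1, 0) lands at (3, 0) in the concatenated output.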
ctx.GetAttr("N")->GetInt() : 1; + + vector inds; + vector vals; + vector shapes; + + vector::Matrix> inds_t; + vector::Vec> vals_t; + vector::Vec> shapes_t; + for (int i = 0; i < N; i++) { + Tensor *indice = ctx.Input(i); + Tensor *value = ctx.Input(i + N); + Tensor *shape = ctx.Input(i + N * 2); + + auto indice_shape = indice->GetTensorShape(); + const int indices_dim = 2; + KERNEL_CHECK_FALSE(indice_shape->GetDims() == indices_dim, KERNEL_STATUS_PARAM_INVALID, + "Input indices should be a matrix but received shape %d at position %d", indice_shape->GetDims(), + i); + + auto value_shape = value->GetTensorShape(); + KERNEL_CHECK_FALSE(value_shape->GetDims() == 1, KERNEL_STATUS_PARAM_INVALID, + "Input values should be a vector but received shape %d at position %d", value_shape->GetDims(), + i); + + auto shape_shape = shape->GetTensorShape(); + KERNEL_CHECK_FALSE(shape_shape->GetDims() == 1, KERNEL_STATUS_PARAM_INVALID, + "Input shapes should be a vector but received shape %d at position %d", shape_shape->GetDims(), + i); + + int64_t ind_dim0 = indice_shape->GetDimSize(0); + int64_t ind_dim1 = indice_shape->GetDimSize(1); + int64_t val_dim0 = value_shape->GetDimSize(0); + int64_t shape_dim0 = shape_shape->GetDimSize(0); + + KERNEL_CHECK_FALSE(ind_dim0 == val_dim0, KERNEL_STATUS_PARAM_INVALID, + "indices dim_size_0 [%lld] != values dim_size_0 [%lld] at position %d", ind_dim0, val_dim0, i); + KERNEL_CHECK_FALSE(ind_dim1 == shape_dim0, KERNEL_STATUS_PARAM_INVALID, + "indices dim_size_1 [%lld] != shapes dim_size_0 [%lld] at position %d", ind_dim1, shape_dim0, i); + + EigenTensor indiceET(indice, indice->GetData()); + EigenTensor valueET(value, value->GetData()); + EigenTensor shapeET(shape, shape->GetData()); + inds_t.push_back(indiceET.matrix()); + vals_t.push_back(valueET.vec()); + shapes_t.push_back(shapeET.vec()); + + inds.push_back(indice); + vals.push_back(value); + shapes.push_back(shape); + } + const typename TTypes::Vec input_shape = shapes_t[0]; + const int input_rank = input_shape.size(); + const int concat_dim = (concat_dim_attr_ < 0) ? 
+  const typename TTypes<int64_t>::Vec input_shape = shapes_t[0];
+  const int input_rank = input_shape.size();
+  const int concat_dim = (concat_dim_attr_ < 0) ? input_rank + concat_dim_attr_ : concat_dim_attr_;
+  KERNEL_CHECK_FALSE(concat_dim >= 0 && concat_dim < input_rank, KERNEL_STATUS_PARAM_INVALID,
+                     "Concat dimension must be in range [%d, %d), got %d", -input_rank, input_rank, concat_dim_attr_);
+  for (int i = 1; i < N; i++) {
+    const typename TTypes<int64_t>::Vec current_shape = shapes_t[i];
+    KERNEL_CHECK_FALSE(current_shape.size() == input_rank, KERNEL_STATUS_PARAM_INVALID,
+                       "Ranks of all input tensors must match: expected %d, but "
+                       "got %d at position %d",
+                       input_rank, current_shape.size(), i);
+    for (int j = 0; j < input_rank; j++) {
+      if (j != concat_dim) {
+        KERNEL_CHECK_FALSE(input_shape(j) == current_shape(j), KERNEL_STATUS_PARAM_INVALID,
+                           "Input shapes must match: expected %d for dimension "
+                           "%d but got %d at position %d",
+                           input_shape(j), j, current_shape(j), i);
+      }
+    }
+  }
+  vector<int64_t> std_order(input_rank);
+  iota(std_order.begin(), std_order.end(), 0);
+
+  vector<int64_t> concat_order;
+  concat_order.reserve(input_rank);
+  concat_order.push_back(concat_dim);
+  for (int j = 0; j < input_rank; ++j) {
+    if (j != concat_dim) {
+      concat_order.push_back(j);
+    }
+  }
+  vector<MySparseTensor *> sp_inputs;
+  for (int i = 0; i < N; ++i) {
+    vector<int64_t> current_shape;
+    for (int j = 0; j < input_rank; j++) current_shape.push_back(shapes_t[i](j));
+    MySparseTensor *tensor = new MySparseTensor();
+    tensor->CreateSparseTensor(inds[i], vals[i], current_shape, std_order);
+    sp_inputs.push_back(std::move(tensor));
+    sp_inputs[i]->Reorder<T>(concat_order);
+  }
+  Tensor *output_ix = ctx.Output(0);
+  Tensor *output_vals = ctx.Output(1);
+
+  MySparseTensor *concat = MySparseTensor::Concat<T>(sp_inputs, output_ix, output_vals);
+  concat->Reorder<T>(std_order);
+
+  Tensor *output_shape_out = ctx.Output(2);
+  EigenTensor output_shapeET(output_shape_out, output_shape_out->GetData());
+  auto output_shape = output_shapeET.vec<int64_t>();
+  auto concat_shape = concat->shape();
+  for (std::size_t i = 0; i < concat_shape.size(); i++) {
+    output_shape(i) = concat_shape[i];
+  }
+  return KERNEL_STATUS_OK;
+}
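+// Inputs are laid out as N indices tensors, then N values tensors, then N
+// shape tensors, so ctx.Input(N) below is the first values tensor and is the
+// right place to read the element data type for dispatch.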
ctx.GetAttr("N")->GetInt() : 1; + KERNEL_HANDLE_ERROR(NormalCheck(ctx, N * kInputNum, kOutputNum), + "SparseConcat check input and output number failed."); + auto data_type = ctx.Input(N)->GetDataType(); + switch (data_type) { + case DT_INT8: + return DoCompute(ctx); + case DT_UINT8: + return DoCompute(ctx); + case DT_INT16: + return DoCompute(ctx); + case DT_UINT16: + return DoCompute(ctx); + case DT_INT32: + return DoCompute(ctx); + case DT_UINT32: + return DoCompute(ctx); + case DT_INT64: + return DoCompute(ctx); + case DT_UINT64: + return DoCompute(ctx); + case DT_FLOAT16: + return DoCompute(ctx); + case DT_FLOAT: + return DoCompute(ctx); + case DT_BOOL: + return DoCompute(ctx); + case DT_DOUBLE: + return DoCompute(ctx); + case DT_COMPLEX64: + return DoCompute>(ctx); + case DT_COMPLEX128: + return DoCompute>(ctx); + default: + KERNEL_LOG_ERROR("SparseConcat kernel data type [%u] not support.", DTypeStr(data_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } +} +REGISTER_CPU_KERNEL(kSparseConcat, SparseConcatCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.h new file mode 100644 index 00000000000..3adab480b3e --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_concat.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef AICPU_KERNELS_NORMALIZED_SPARSE_CONCAT_H_ +#define AICPU_KERNELS_NORMALIZED_SPARSE_CONCAT_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class SparseConcatCpuKernel : public CpuKernel { + public: + ~SparseConcatCpuKernel() = default; + uint32_t Compute(CpuKernelContext &ctx) override; +}; +} // namespace aicpu +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.cc new file mode 100644 index 00000000000..92f8029cdbc --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.cc @@ -0,0 +1,48 @@ +#include "sparse_dense_cwise_add.h" +#include +#include "utils/kernel_util.h" +#include "utils/sparse_dense_cwise_utils.h" + +namespace aicpu { +namespace { +const char *kSparseDenseCwiseAdd = "SparseDenseCwiseAdd"; + +#define SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DTYPE, TYPE, CTX) \ + case (DTYPE): { \ + uint32_t result = SparseDenseCwiseOpCompute(CTX); \ + if (result != KERNEL_STATUS_OK) { \ + KERNEL_LOG_ERROR("SparseDenseCwiseAdd kernel compute failed."); \ + return result; \ + } \ + break; \ + } +} // namespace + +uint32_t SparseDenseCwiseAddKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseAdd check params failed."); + + auto data_type = ctx.Input(1)->GetDataType(); + switch (data_type) { + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT8, int8_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT16, int16_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT32, int32_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT64, int64_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT8, uint8_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT16, uint16_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT32, uint32_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT64, uint64_t, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_DOUBLE, double, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_FLOAT, float, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_COMPLEX64, std::complex, ctx) + SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_COMPLEX128, std::complex, ctx) + default: + KERNEL_LOG_ERROR("SparseDenseCwiseAdd kernel data type %s not support.", DTypeStr(data_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + + return KERNEL_STATUS_OK; +} + +REGISTER_CPU_KERNEL(kSparseDenseCwiseAdd, SparseDenseCwiseAddKernel); +} // namespace aicpu \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.h new file mode 100644 index 00000000000..4507a88e579 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_add.h @@ -0,0 +1,17 @@ +#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_ADD_H_ +#define AICPU_KERNELS_SPARSE_DENSE_CWISE_ADD_H_ + +#include "utils/sparse_dense_cwise_utils.h" + +namespace aicpu { +class SparseDenseCwiseAddKernel : public SparseDenseCwiseOpKernel { + public: + SparseDenseCwiseAddKernel() = default; + ~SparseDenseCwiseAddKernel() override = default; + + protected: + uint32_t Compute(CpuKernelContext &ctx) override; +}; + +} // namespace aicpu +#endif \ No newline at end of 
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.cc
new file mode 100644
index 00000000000..25b4ea1da45
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.cc
@@ -0,0 +1,48 @@
+#include "sparse_dense_cwise_div.h"
+#include <complex>
+#include "utils/kernel_util.h"
+#include "utils/sparse_dense_cwise_utils.h"
+
+namespace aicpu {
+namespace {
+const char *kSparseDenseCwiseDiv = "SparseDenseCwiseDiv";
+
+#define SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DTYPE, TYPE, CTX)          \
+  case (DTYPE): {                                                      \
+    uint32_t result = SparseDenseCwiseOpCompute<TYPE>(CTX);            \
+    if (result != KERNEL_STATUS_OK) {                                  \
+      KERNEL_LOG_ERROR("SparseDenseCwiseDiv kernel compute failed.");  \
+      return result;                                                   \
+    }                                                                  \
+    break;                                                             \
+  }
+}  // namespace
+
+uint32_t SparseDenseCwiseDivKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseDiv check params failed.");
+
+  auto data_type = ctx.Input(1)->GetDataType();
+  switch (data_type) {
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT8, int8_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT16, int16_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT32, int32_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT64, int64_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_DOUBLE, double, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
+    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
+    default:
+      KERNEL_LOG_ERROR("SparseDenseCwiseDiv kernel data type %s not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+REGISTER_CPU_KERNEL(kSparseDenseCwiseDiv, SparseDenseCwiseDivKernel);
+}  // namespace aicpu
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.h
new file mode 100644
index 00000000000..afcea0a5655
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_div.h
@@ -0,0 +1,17 @@
+#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_DIV_H_
+#define AICPU_KERNELS_SPARSE_DENSE_CWISE_DIV_H_
+
+#include "utils/sparse_dense_cwise_utils.h"
+
+namespace aicpu {
+class SparseDenseCwiseDivKernel : public SparseDenseCwiseOpKernel<DivOp> {
+ public:
+  SparseDenseCwiseDivKernel() = default;
+  ~SparseDenseCwiseDivKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+};
+
+}  // namespace aicpu
+#endif
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.cc
new file mode 100644
index 00000000000..b9b2f04fb02
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.cc
@@ -0,0 +1,48 @@
+#include "sparse_dense_cwise_mul.h"
+#include <complex>
+#include "utils/kernel_util.h"
+#include "utils/sparse_dense_cwise_utils.h"
+
+namespace aicpu {
+namespace {
+const char *kSparseDenseCwiseMul = "SparseDenseCwiseMul";
+
+#define SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DTYPE, TYPE, CTX)          \
+  case (DTYPE): {                                                      \
+    uint32_t result = SparseDenseCwiseOpCompute<TYPE>(CTX);            \
+    if (result != KERNEL_STATUS_OK) {                                  \
+      KERNEL_LOG_ERROR("SparseDenseCwiseMul kernel compute failed.");  \
+      return result;                                                   \
+    }                                                                  \
+    break;                                                             \
+  }
+}  // namespace
+
+uint32_t SparseDenseCwiseMulKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseMul check params failed.");
+
+  auto data_type = ctx.Input(1)->GetDataType();
+  switch (data_type) {
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_DOUBLE, double, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_FLOAT, float, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
+    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
+    default:
+      KERNEL_LOG_ERROR("SparseDenseCwiseMul kernel data type %s not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+REGISTER_CPU_KERNEL(kSparseDenseCwiseMul, SparseDenseCwiseMulKernel);
+}  // namespace aicpu
\ No newline at end of file
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.h
new file mode 100644
index 00000000000..8256c36d5cf
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_mul.h
@@ -0,0 +1,17 @@
+#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_MUL_H_
+#define AICPU_KERNELS_SPARSE_DENSE_CWISE_MUL_H_
+
+#include "utils/sparse_dense_cwise_utils.h"
+
+namespace aicpu {
+class SparseDenseCwiseMulKernel : public SparseDenseCwiseOpKernel<MulOp> {
+ public:
+  SparseDenseCwiseMulKernel() = default;
+  ~SparseDenseCwiseMulKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+};
+
+}  // namespace aicpu
+#endif
\ No newline at end of file
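Note: all three kernels funnel into the shared SparseDenseCwiseOpKernel<Op> implementation below. One C++ detail worth noting there: the sharder lambdas return uint32_t statuses while CpuKernelUtils::ParallelFor presumably takes a void(start, end) shard callable, as in other aicpu kernels; that still compiles because std::function<void(...)> discards the wrapped callable's return value. A minimal, self-contained demonstration:

#include <cstdint>
#include <functional>
#include <iostream>

int main() {
  // A callable returning uint32_t bound to std::function<void(...)>: legal,
  // the return value is simply discarded at the call site.
  std::function<void(int64_t, int64_t)> shard = [](int64_t start, int64_t end) {
    std::cout << "shard [" << start << ", " << end << ")\n";
    return 0u;
  };
  shard(0, 8);  // prints: shard [0, 8)
}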
"kernel_log.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" +#include "utils/sparse_tensor.h" + +namespace aicpu { +namespace { +const uint32_t kInputNum_SparseDenseCwiseOp = 4; +const uint32_t kOutputNum_SparseDenseCwiseOp = 1; +const int64_t kParallelDataNum = 2 * 1024; +const int64_t kParallelDataNumSameShape = 7 * 1024; +} // namespace + +template +uint32_t SparseDenseCwiseOpKernel::CheckParams(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum_SparseDenseCwiseOp, kOutputNum_SparseDenseCwiseOp), + "SparseDenseCwise%s normal check failed.", Op::Name().c_str()); + + Tensor *x1_indices = ctx.Input(0); + Tensor *x1_values = ctx.Input(1); + Tensor *x1_shape = ctx.Input(2); + Tensor *x2 = ctx.Input(3); + Tensor *y = ctx.Output(0); + + DataType x1_indices_type = x1_indices->GetDataType(); + DataType x1_values_type = x1_values->GetDataType(); + DataType x1_shape_type = x1_shape->GetDataType(); + DataType x2_type = x2->GetDataType(); + DataType y_type = y->GetDataType(); + KERNEL_CHECK_FALSE((x1_indices_type == x1_shape_type), KERNEL_STATUS_PARAM_INVALID, + "The data type of x1_indices_type [%s] need be same with " + "x1_shape [%s].", + DTypeStr(x1_indices_type).c_str(), DTypeStr(x1_shape_type).c_str()) + KERNEL_CHECK_FALSE(((x1_values_type == x2_type) && (x1_values_type == y_type)), KERNEL_STATUS_PARAM_INVALID, + "The data type of x1_values_type [%s] need be same with " + "x2_type[%s] and y_type [%s].", + DTypeStr(x1_values_type).c_str(), DTypeStr(x2_type).c_str(), DTypeStr(y_type).c_str()) + KERNEL_CHECK_FALSE((x1_indices_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID, + "The data type of x1_indices_type [%s] need be int_64.", DTypeStr(x1_indices_type).c_str()) + int32_t input0_dims = x1_indices->GetTensorShape()->GetDims(); + int32_t input1_dims = x1_values->GetTensorShape()->GetDims(); + int32_t input2_dims = x1_shape->GetTensorShape()->GetDims(); + int32_t input3_dims = x2->GetTensorShape()->GetDims(); + int32_t output_dims = y->GetTensorShape()->GetDims(); + int64_t shape_elements_nums = x1_shape->GetTensorShape()->NumElements(); + int64_t indices_0 = x1_indices->GetTensorShape()->GetDimSize(0); + int64_t value_0 = x1_values->GetTensorShape()->GetDimSize(0); + KERNEL_CHECK_FALSE((int(input0_dims) == 2), KERNEL_STATUS_PARAM_INVALID, "The dims of input0 need be 2.") + KERNEL_CHECK_FALSE((input1_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of input1 need be 1 .") + KERNEL_CHECK_FALSE((input2_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of input2 need be 1.") + KERNEL_CHECK_FALSE((output_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of output need be 1.") + KERNEL_CHECK_FALSE((input3_dims <= shape_elements_nums), KERNEL_STATUS_PARAM_INVALID, + "The dims of DenseTensor is large than sparseTensor.") + KERNEL_CHECK_FALSE((indices_0 == value_0), KERNEL_STATUS_PARAM_INVALID, "The num of indices is not equal to value.") + return KERNEL_STATUS_OK; +} + +template +template +uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpSpecialCompute(BcastShapeType type, CpuKernelContext &ctx) { + auto sparse_indices_data = reinterpret_cast(ctx.Input(0)->GetData()); + auto sparse_values_data = reinterpret_cast(ctx.Input(1)->GetData()); + auto sparse_shape_data = reinterpret_cast(ctx.Input(2)->GetData()); + auto dense_data = reinterpret_cast(ctx.Input(3)->GetData()); + auto output_data = reinterpret_cast(ctx.Output(0)->GetData()); + + int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0); + int64_t dimension = 
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialCompute(BcastShapeType type, CpuKernelContext &ctx) {
+  auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
+  auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+  auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
+  auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+
+  int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
+  int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  int64_t data_num = ctx.Input(1)->NumElements();
+
+  std::vector<T> sparse_values_vec(data_num);
+  for (int64_t i = 0; i < data_num; i++) {
+    sparse_values_vec[i] = (sparse_values_data[i]);
+  }
+  int64_t dims = ctx.Input(2)->NumElements();
+  int64_t Sparse_numelements = 1;
+  for (int64_t i = 0; i < dims; i++) {
+    Sparse_numelements *= sparse_shape_data[i];
+  }
+  if (Sparse_numelements >= kParallelDataNumSameShape) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
+    if (max_core_num == 0) {
+      KERNEL_LOG_ERROR("max_core_num could not be 0");
+    }
+    if (max_core_num > value_nums) {
+      max_core_num = value_nums;
+    }
+
+    auto sharder_Op = [&](int64_t start, int64_t end) {
+      switch (type) {
+        case BcastShapeType::SAME_SHAPE:
+          for (int64_t i = start; i < end; i++) {
+            int index = 0;
+            for (int64_t j = 0; j < dimension - 1; j++) {
+              int c = 1;
+              for (int k = j + 1; k < dimension; k++) {
+                c = c * sparse_shape_data[k];
+              }
+              index += c * sparse_indices_data[j + i * dimension];
+            }
+            index += sparse_indices_data[(i + 1) * dimension - 1];
+            std::string name = Op::Name().c_str();
+            if (name == "Add") {
+              output_data[i] = sparse_values_vec[i] + dense_data[index];
+            } else if (name == "Div") {
+              if (fabs(double(dense_data[index])) < 1e-6) {
+                KERNEL_LOG_ERROR("Cannot be divided by 0");
+                return KERNEL_STATUS_PARAM_INVALID;
+              } else {
+                output_data[i] = sparse_values_vec[i] / dense_data[index];
+              }
+            } else {
+              output_data[i] = sparse_values_vec[i] * dense_data[index];
+            }
+          }
+          break;
+
+        case BcastShapeType::Y_ONE_ELEMENT:
+          for (int64_t i = start; i < end; i++) {
+            std::string name = Op::Name().c_str();
+            if (name == "Add") {
+              output_data[i] = sparse_values_data[i] + *(dense_data);
+            } else if (name == "Div") {
+              if (fabs(double(*(dense_data))) < 1e-6) {
+                KERNEL_LOG_ERROR("Cannot be divided by 0");
+                return KERNEL_STATUS_PARAM_INVALID;
+              } else {
+                output_data[i] = sparse_values_data[i] / *(dense_data);
+              }
+            } else {
+              output_data[i] = *(dense_data) * sparse_values_data[i];
+            }
+          }
+          break;
+        default:
+          KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
+          break;
+      }
+      return KERNEL_STATUS_OK;
+    };
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
+                        "Op Compute failed.");
+  } else {
+    switch (type) {
+      case BcastShapeType::SAME_SHAPE:
+        for (int64_t i = 0; i < value_nums; i++) {
+          int index = 0;
+          for (int64_t j = 0; j < dimension - 1; j++) {
+            int c = 1;
+            for (int k = j + 1; k < dimension; k++) {
+              c = c * sparse_shape_data[k];
+            }
+            index += c * sparse_indices_data[j + i * dimension];
+          }
+          index += sparse_indices_data[(i + 1) * dimension - 1];
+          std::string name = Op::Name().c_str();
+          if (name == "Add") {
+            output_data[i] = sparse_values_vec[i] + dense_data[index];
+          } else if (name == "Div") {
+            if (fabs(double(dense_data[index])) < 1e-6) {
+              KERNEL_LOG_ERROR("Cannot be divided by 0");
+              return KERNEL_STATUS_PARAM_INVALID;
+            } else {
+              output_data[i] = sparse_values_vec[i] / dense_data[index];
+            }
+          } else {
+            output_data[i] = sparse_values_vec[i] * dense_data[index];
+          }
+        }
+        break;
+
+      case BcastShapeType::Y_ONE_ELEMENT:
+        for (int64_t i = 0; i < value_nums; i++) {
+          std::string name = Op::Name().c_str();
+          if (name == "Add") {
+            output_data[i] = sparse_values_data[i] + *(dense_data);
+          } else if (name == "Div") {
+            if (fabs(double(*(dense_data))) < 1e-6) {
+              KERNEL_LOG_ERROR("Cannot be divided by 0");
+              return KERNEL_STATUS_PARAM_INVALID;
+            } else {
+              output_data[i] = sparse_values_data[i] / *(dense_data);
+            }
+          } else {
+            output_data[i] = *(dense_data) * sparse_values_data[i];
+          }
+        }
+        break;
+      default:
+        KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
+        break;
+    }
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpNoBcastCompute(CpuKernelContext &ctx) {
+  auto *input2_tensor = ctx.Input(2);
+  auto *input3_tensor = ctx.Input(3);
+  int64_t dimension = input2_tensor->NumElements();
+  int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
+  BcastShapeType type = dimension == dense_dims ? BcastShapeType::SAME_SHAPE : BcastShapeType::Y_ONE_ELEMENT;
+  SparseDenseCwiseOpSpecialCompute<T>(type, ctx);
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastCompute(CpuKernelContext &ctx) {
+  auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
+  auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+  auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
+  auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+
+  int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
+  int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  auto dense_shape = ctx.Input(3)->GetTensorShape()->GetDimSizes();
+  int64_t dims = ctx.Input(2)->NumElements();
+  int64_t data_num = ctx.Input(1)->NumElements();
+
+  int64_t Sparse_numelements = 1;
+  for (int64_t i = 0; i < dims; i++) {
+    Sparse_numelements *= sparse_shape_data[i];
+  }
+
+  std::vector<T> sparse_values_vec(data_num);
+  for (int64_t i = 0; i < data_num; i++) {
+    sparse_values_vec[i] = (sparse_values_data[i]);
+  }
+
+  std::vector<int64_t> sparse_shape(dimension);
+  for (int64_t i = 0; i < dimension; i++) {
+    sparse_shape[i] = sparse_shape_data[i];
+  }
+  std::vector<int64_t> sparse_shape1(dimension);
+  for (int64_t j = 0; j < dimension; j++) {
+    sparse_shape1[j] = sparse_shape[j];
+  }
+
+  BroadcastIterator broad_base_iter_1(sparse_shape, dense_shape, sparse_shape1);
+  std::vector<T> Dense(Sparse_numelements);
+  broad_base_iter_1.SetPos(0);
+  for (int64_t i = 0; i < Sparse_numelements; i++) {
+    Dense[i] = dense_data[broad_base_iter_1.GetInputPosB()];
+    broad_base_iter_1.GenNextPos();
+  }
+  if (Sparse_numelements >= kParallelDataNum) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
+    if (max_core_num == 0) {
+      KERNEL_LOG_ERROR("max_core_num could not be 0");
+    }
+    if (max_core_num > value_nums) {
+      max_core_num = value_nums;
+    }
+    auto sharder_Op = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        int index = 0;
+        for (int64_t j = 0; j < dimension - 1; j++) {
+          int c = 1;
+          for (int k = j + 1; k < dimension; k++) {
+            c = c * sparse_shape_data[k];
+          }
+          index += sparse_indices_data[j + i * dimension] * c;
+        }
+        index += sparse_indices_data[(i + 1) * dimension - 1];
+        std::string name = Op::Name().c_str();
+        if (name == "Add") {
+          output_data[i] = sparse_values_vec[i] + Dense[index];
+        } else if (name == "Div") {
+          if (fabs(double(Dense[index])) < 1e-6) {
+            KERNEL_LOG_ERROR("Cannot be divided by 0");
+            return KERNEL_STATUS_PARAM_INVALID;
+          } else {
+            output_data[i] = sparse_values_vec[i] / Dense[index];
+          }
+        } else {
+          output_data[i] = sparse_values_vec[i] * Dense[index];
+        }
+      }
+      return KERNEL_STATUS_OK;
+    };
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
+                        "Op Compute failed.");
+  } else {
+    for (int64_t i = 0; i < value_nums; i++) {
+      int index = 0;
+      for (int64_t j = 0; j < dimension - 1; j++) {
+        int c = 1;
+        for (int k = j + 1; k < dimension; k++) {
+          c = c * sparse_shape_data[k];
+        }
+        index += sparse_indices_data[j + i * dimension] * c;
+      }
+      index += sparse_indices_data[(i + 1) * dimension - 1];
+      std::string name = Op::Name().c_str();
+      if (name == "Add") {
+        output_data[i] = sparse_values_vec[i] + Dense[index];
+      } else if (name == "Div") {
+        if (fabs(double(Dense[index])) < 1e-6) {
+          KERNEL_LOG_ERROR("Cannot be divided by 0");
+          return KERNEL_STATUS_PARAM_INVALID;
+        } else {
+          output_data[i] = sparse_values_vec[i] / Dense[index];
+        }
+      } else {
+        output_data[i] = sparse_values_vec[i] * Dense[index];
+      }
+    }
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpCompute(CpuKernelContext &ctx) {
+  auto data_type = ctx.Input(1)->GetDataType();
+  switch (data_type) {
+    case DT_INT8:
+      return ComputeOp<int8_t>(ctx);
+    case DT_INT16:
+      return ComputeOp<int16_t>(ctx);
+    case DT_INT32:
+      return ComputeOp<int32_t>(ctx);
+    case DT_INT64:
+      return ComputeOp<int64_t>(ctx);
+    case DT_UINT8:
+      return ComputeOp<uint8_t>(ctx);
+    case DT_UINT16:
+      return ComputeOp<uint16_t>(ctx);
+    case DT_UINT32:
+      return ComputeOp<uint32_t>(ctx);
+    case DT_UINT64:
+      return ComputeOp<uint64_t>(ctx);
+    case DT_FLOAT16:
+      return ComputeOp<Eigen::half>(ctx);
+    case DT_FLOAT:
+      return ComputeOp<float>(ctx);
+    case DT_DOUBLE:
+      return ComputeOp<double>(ctx);
+    case DT_COMPLEX64:
+      return ComputeOpComplex<std::complex<float>>(ctx);
+    case DT_COMPLEX128:
+      return ComputeOpComplex<std::complex<double>>(ctx);
+    default:
+      KERNEL_LOG_ERROR("sparse_dense_cwise kernel data type [%s] not support.", DTypeStr(data_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+}
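+// ComputeOp/ComputeOpComplex below accept the dense operand when, aligning
+// trailing dimensions, every dense dimension equals the sparse dimension or
+// is 1. Example: sparse shape (2, 3, 4) with dense shape (3, 4), (1, 4) or
+// (4) broadcasts; dense shape (2, 4) does not.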
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::ComputeOp(CpuKernelContext &ctx) {
+  auto *input3_tensor = ctx.Input(3);
+  auto dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
+  auto dense_shape = input3_tensor->GetTensorShape()->GetDimSizes();
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+  int64_t dense_num = ctx.Input(3)->GetTensorShape()->NumElements();
+
+  std::vector<int64_t> sparse_shape(dimension);
+  for (int64_t i = 0; i < dimension; i++) {
+    sparse_shape[i] = sparse_shape_data[i];
+  }
+
+  bool isNeedBcast = (dense_shape == sparse_shape || dense_num == 1);
+  if (isNeedBcast) {
+    return SparseDenseCwiseOpNoBcastCompute<T>(ctx);
+  } else {
+    if (dense_dims <= dimension) {
+      for (int i = dense_dims - 1; i >= 0; --i) {
+        if ((dense_shape[i] != 1) && (dense_shape[i] != sparse_shape[i + dimension - dense_dims])) {
+          return KERNEL_STATUS_PARAM_INVALID;
+        }
+      }
+      return SparseDenseCwiseOpBcastCompute<T>(ctx);
+    } else {
+      return KERNEL_STATUS_PARAM_INVALID;
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::ComputeOpComplex(CpuKernelContext &ctx) {
+  auto *input2_tensor = ctx.Input(2);
+  auto *input3_tensor = ctx.Input(3);
+  int64_t dense_num = ctx.Input(3)->GetTensorShape()->NumElements();
+  int64_t dimension = input2_tensor->NumElements();
+  int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
+  auto dense_shape = input3_tensor->GetTensorShape()->GetDimSizes();
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+
+  std::vector<int64_t> sparse_shape(dimension);
+  for (int64_t i = 0; i < dimension; i++) {
+    sparse_shape[i] = sparse_shape_data[i];
+  }
+
+  bool isNeedBcast = (dense_shape == sparse_shape || dense_num == 1);
+  if (isNeedBcast) {
+    return SparseDenseCwiseOpNoBcastComputeComplex<T>(ctx);
+  } else {
+    if (dense_dims <= dimension) {
+      for (int i = dense_dims - 1; i >= 0; --i) {
+        if ((dense_shape[i] != 1) && (dense_shape[i] != sparse_shape[i + dimension - dense_dims])) {
+          return KERNEL_STATUS_PARAM_INVALID;
+        }
+      }
+      return SparseDenseCwiseOpBcastComputeComplex<T>(ctx);
+    } else {
+      return KERNEL_STATUS_PARAM_INVALID;
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialComputeComplex(BcastShapeType type,
+                                                                               CpuKernelContext &ctx) {
+  auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
+  auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+  auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
+  auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+
+  int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
+  int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  int64_t data_num = ctx.Input(1)->NumElements();
+  std::vector<T> sparse_values_vec(data_num);
+  for (int64_t i = 0; i < data_num; i++) {
+    sparse_values_vec[i] = (sparse_values_data[i]);
+  }
+  int64_t dims = ctx.Input(2)->NumElements();
+  int64_t Sparse_numelements = 1;
+  for (int64_t i = 0; i < dims; i++) {
+    Sparse_numelements *= sparse_shape_data[i];
+  }
+  if (Sparse_numelements >= kParallelDataNumSameShape) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
+    if (max_core_num == 0) {
+      KERNEL_LOG_ERROR("max_core_num could not be 0");
+    }
+    if (max_core_num > value_nums) {
+      max_core_num = value_nums;
+    }
+
+    auto sharder_Op = [&](int64_t start, int64_t end) {
+      switch (type) {
+        case BcastShapeType::SAME_SHAPE:
+          for (int64_t i = start; i < end; i++) {
+            int index = 0;
+            for (int64_t j = 0; j < dimension - 1; j++) {
+              int c = 1;
+              for (int k = j + 1; k < dimension; k++) {
+                c = c * sparse_shape_data[k];
+              }
+              index += c * sparse_indices_data[j + i * dimension];
+            }
+            index += sparse_indices_data[(i + 1) * dimension - 1];
+            std::string name = Op::Name().c_str();
+            if (name == "Add") {
+              output_data[i] = sparse_values_vec[i] + dense_data[index];
+            } else if (name == "Div") {
+              if (std::abs(dense_data[index]) < 1e-6) {
+                KERNEL_LOG_ERROR("Cannot be divided by 0");
+                return KERNEL_STATUS_PARAM_INVALID;
+              } else {
+                output_data[i] = sparse_values_vec[i] / dense_data[index];
+              }
+            } else {
+              output_data[i] = sparse_values_vec[i] * dense_data[index];
+            }
+          }
+          break;
+
+        case BcastShapeType::Y_ONE_ELEMENT:
+          for (int64_t i = start; i < end; i++) {
+            std::string name = Op::Name().c_str();
+            if (name == "Add") {
+              output_data[i] = sparse_values_data[i] + *(dense_data);
+            } else if (name == "Div") {
+              if (std::abs(*(dense_data)) < 1e-6) {
+                KERNEL_LOG_ERROR("Cannot be divided by 0");
+                return KERNEL_STATUS_PARAM_INVALID;
+              } else {
+                output_data[i] = sparse_values_data[i] / *(dense_data);
+              }
+            } else {
+              output_data[i] = *(dense_data) * sparse_values_data[i];
+            }
+          }
+          break;
+        default:
+          KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
+          break;
+      }
+      return KERNEL_STATUS_OK;
+    };
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
+                        "Op Compute failed.");
+  } else {
+    switch (type) {
+      case BcastShapeType::SAME_SHAPE:
+        for (int64_t i = 0; i < value_nums; i++) {
+          int index = 0;
+          for (int64_t j = 0; j < dimension - 1; j++) {
+            int c = 1;
+            for (int k = j + 1; k < dimension; k++) {
+              c = c * sparse_shape_data[k];
+            }
+            index += c * sparse_indices_data[j + i * dimension];
+          }
+          index += sparse_indices_data[(i + 1) * dimension - 1];
+          std::string name = Op::Name().c_str();
+          if (name == "Add") {
+            output_data[i] = sparse_values_vec[i] + dense_data[index];
+          } else if (name == "Div") {
+            if (std::abs(dense_data[index]) < 1e-6) {
+              KERNEL_LOG_ERROR("Cannot be divided by 0");
+              return KERNEL_STATUS_PARAM_INVALID;
+            } else {
+              output_data[i] = sparse_values_vec[i] / dense_data[index];
+            }
+          } else {
+            output_data[i] = sparse_values_vec[i] * dense_data[index];
+          }
+        }
+        break;
+
+      case BcastShapeType::Y_ONE_ELEMENT:
+        for (int64_t i = 0; i < value_nums; i++) {
+          std::string name = Op::Name().c_str();
+          if (name == "Add") {
+            output_data[i] = sparse_values_data[i] + *(dense_data);
+          } else if (name == "Div") {
+            if (std::abs(*(dense_data)) < 1e-6) {
+              KERNEL_LOG_ERROR("Cannot be divided by 0");
+              return KERNEL_STATUS_PARAM_INVALID;
+            } else {
+              output_data[i] = sparse_values_data[i] / *(dense_data);
+            }
+          } else {
+            output_data[i] = *(dense_data) * sparse_values_data[i];
+          }
+        }
+        break;
+      default:
+        KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
+        break;
+    }
+  }
+
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpNoBcastComputeComplex(CpuKernelContext &ctx) {
+  auto *input2_tensor = ctx.Input(2);
+  auto *input3_tensor = ctx.Input(3);
+  int64_t dimension = input2_tensor->NumElements();
+  int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
+
+  BcastShapeType type = dimension == dense_dims ? BcastShapeType::SAME_SHAPE : BcastShapeType::Y_ONE_ELEMENT;
+  SparseDenseCwiseOpSpecialComputeComplex<T>(type, ctx);
+  return KERNEL_STATUS_OK;
+}
+
+template <typename Op>
+template <typename T>
+uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastComputeComplex(CpuKernelContext &ctx) {
+  auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
+  auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
+  auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
+  auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
+  auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
+
+  int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
+  int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  auto dense_shape = ctx.Input(3)->GetTensorShape()->GetDimSizes();
+  int64_t dims = ctx.Input(2)->NumElements();
+  int64_t data_num = ctx.Input(1)->NumElements();
+
+  int64_t Sparse_numelements = 1;
+  for (int64_t i = 0; i < dims; i++) {
+    Sparse_numelements *= sparse_shape_data[i];
+  }
+
+  std::vector<T> sparse_values_vec(data_num);
+  for (int64_t i = 0; i < data_num; i++) {
+    sparse_values_vec[i] = (sparse_values_data[i]);
+  }
+
+  std::vector<int64_t> sparse_shape(dimension);
+  for (int64_t i = 0; i < dimension; i++) {
+    sparse_shape[i] = sparse_shape_data[i];
+  }
+  std::vector<int64_t> sparse_shape1(dimension);
+  for (int64_t j = 0; j < dimension; j++) {
+    sparse_shape1[j] = sparse_shape[j];
+  }
+
+  BroadcastIterator broad_base_iter_1(sparse_shape, dense_shape, sparse_shape1);
+  std::vector<T> Dense(Sparse_numelements);
+  broad_base_iter_1.SetPos(0);
+  for (int64_t i = 0; i < Sparse_numelements; i++) {
+    Dense[i] = dense_data[broad_base_iter_1.GetInputPosB()];
+    broad_base_iter_1.GenNextPos();
+  }
+
+  if (Sparse_numelements >= kParallelDataNum) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
+    if (max_core_num == 0) {
+      KERNEL_LOG_ERROR("max_core_num could not be 0");
+    }
+    if (max_core_num > value_nums) {
+      max_core_num = value_nums;
+    }
+
+    auto sharder_Op = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        int index = 0;
+        for (int64_t j = 0; j < dimension - 1; j++) {
+          int c = 1;
+          for (int k = j + 1; k < dimension; k++) {
+            c = c * sparse_shape_data[k];
+          }
+          index += sparse_indices_data[j + i * dimension] * c;
+        }
+        index += sparse_indices_data[(i + 1) * dimension - 1];
+        std::string name = Op::Name().c_str();
+        if (name == "Add") {
+          output_data[i] = sparse_values_vec[i] + Dense[index];
+        } else if (name == "Div") {
+          if (std::abs(Dense[index]) < 1e-6) {
+            KERNEL_LOG_ERROR("Cannot be divided by 0");
+            return KERNEL_STATUS_PARAM_INVALID;
+          } else {
+            output_data[i] = sparse_values_vec[i] / Dense[index];
+          }
+        } else {
+          output_data[i] = sparse_values_vec[i] * Dense[index];
+        }
+      }
+      return KERNEL_STATUS_OK;
+    };
+
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
+                        "Op Compute failed.");
+  } else {
+    for (int64_t i = 0; i < value_nums; i++) {
+      int index = 0;
+      for (int64_t j = 0; j < dimension - 1; j++) {
+        int c = 1;
+        for (int k = j + 1; k < dimension; k++) {
+          c = c * sparse_shape_data[k];
+        }
+        index += sparse_indices_data[j + i * dimension] * c;
+      }
+      index += sparse_indices_data[(i + 1) * dimension - 1];
+      std::string name = Op::Name().c_str();
+      if (name == "Add") {
+        output_data[i] = sparse_values_vec[i] + Dense[index];
+      } else if (name == "Div") {
+        if (std::abs(Dense[index]) < 1e-6) {
KERNEL_LOG_ERROR("Cannot be divided by 0"); + return KERNEL_STATUS_PARAM_INVALID; + } else { + output_data[i] = sparse_values_vec[i] / Dense[index]; + } + } else { + output_data[i] = sparse_values_vec[i] * Dense[index]; + } + } + } + return KERNEL_STATUS_OK; +} + +template class SparseDenseCwiseOpKernel; +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute>( + CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute>( + CpuKernelContext &ctx); + +template class SparseDenseCwiseOpKernel; +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute>( + CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute>( + CpuKernelContext &ctx); + +template class SparseDenseCwiseOpKernel; +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template uint32_t SparseDenseCwiseOpKernel::SparseDenseCwiseOpCompute(CpuKernelContext &ctx); +template 
+
+template class SparseDenseCwiseOpKernel<MulOp>;
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int8_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int16_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int32_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int64_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint8_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint16_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint32_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint64_t>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<Eigen::half>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<float>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<double>(CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<std::complex<float>>(
+  CpuKernelContext &ctx);
+template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<std::complex<double>>(
+  CpuKernelContext &ctx);
+}  // namespace aicpu
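Note: the explicit instantiations above are needed because the member template bodies live in this .cc file, so every <Op, T> combination reached by the three kernels must be instantiated here or the linker reports undefined symbols. A minimal, compilable analogue of the pattern (Kernel, Twice, and AddOp here are illustrative names):

// A member function template defined outside the header, as in
// sparse_dense_cwise_utils.cc.
template <typename Op>
struct Kernel {
  template <typename T>
  T Twice(T v);  // declared here, defined below in the .cc
};

template <typename Op>
template <typename T>
T Kernel<Op>::Twice(T v) {
  return v + v;
}

struct AddOp {};
template struct Kernel<AddOp>;                 // instantiate the class
template int Kernel<AddOp>::Twice<int>(int);   // instantiate the member template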