forked from mindspore-Ecosystem/mindspore

merge canndev code to mindspore

parent a4e4ab6da7
commit 8103044f80

@@ -339,3 +339,7 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_max_pool.cc:aicpu::FractionalMaxPoolCpuKernel::DoCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_avg_pool.cc:aicpu::FractionalAvgPoolCpuKernel::DoCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/densetosparsesetoperation.cc:aicpu::DenseToSparseSetOperationCpuKernel::ComputeDenseToSparse
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialComputeComplex
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastComputeComplex
@@ -173,19 +173,7 @@ status_t SetThreadLocalCtx(const std::string &key, const std::string &value) {
   return AICPU_ERROR_NONE;
 }
 
-status_t GetThreadLocalCtx(const std::string &key, std::string *value) {
-  if (key.empty()) {
-    AICPU_LOGE("get thread local context failed, key is empty");
-    return AICPU_ERROR_FAILED;
-  }
-  auto iter = g_thread_local_ctx.find(key);
-  if (iter != g_thread_local_ctx.end()) {
-    *value = iter->second;
-    return AICPU_ERROR_NONE;
-  }
-  AICPU_LOGW("get thread local context failed, no such key[%s]", key.c_str());
-  return AICPU_ERROR_FAILED;
-}
+status_t GetThreadLocalCtx(const std::string &key, std::string *value) { return AICPU_ERROR_NONE; }
 
 status_t RemoveThreadLocalCtx(const std::string &key) {
   auto iter = g_thread_local_ctx.find(key);
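The hunk above swaps the map-backed lookup for a stub: the new GetThreadLocalCtx returns AICPU_ERROR_NONE unconditionally and never writes to *value. A minimal caller-side sketch of the difference (hypothetical caller code, not part of this commit; it assumes only the context_util API visible in the hunk):

  // Hypothetical caller, for illustration only.
  std::string stream_id = "unset";
  status_t ret = GetThreadLocalCtx("streamId", &stream_id);
  // Removed implementation: ret == AICPU_ERROR_FAILED when "streamId" is absent,
  //                          stream_id filled from g_thread_local_ctx when present.
  // Stub implementation:     ret == AICPU_ERROR_NONE in every case and stream_id
  //                          stays "unset", so callers cannot detect a missing key.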
sparse_concat.cc (new file)
@@ -0,0 +1,375 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sparse_concat.h"
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "Eigen/Core"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "status.h"
|
||||
using namespace std;
|
||||
namespace {
|
||||
const char *kSparseConcat = "SparseConcat";
|
||||
const uint32_t kOutputNum = 3;
|
||||
const uint32_t kInputNum = 3;
|
||||
} // namespace
|
||||
namespace aicpu {
|
||||
class MySparseTensor {
|
||||
public:
|
||||
MySparseTensor() : dims_(0) {}
|
||||
~MySparseTensor() = default;
|
||||
uint32_t CreateSparseTensor(Tensor *ix, Tensor *vals, std::vector<int64_t> shape, std::vector<int64_t> order) {
|
||||
int64_t dims = ix->GetTensorShape()->GetDimSize(1);
|
||||
ix_ = std::make_shared<EigenTensor>(ix, ix->GetData());
|
||||
vals_ = std::make_shared<EigenTensor>(vals, vals->GetData());
|
||||
shape_.assign(shape.begin(), shape.end());
|
||||
order_.assign(order.begin(), order.end());
|
||||
dims_ = dims;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
class DimComparator {
|
||||
public:
|
||||
DimComparator(const TTypes<int64_t>::Matrix &ix, const std::vector<int64_t> &order,
|
||||
const std::vector<int64_t> &shape)
|
||||
: ix_(ix), order_(order), dims_(shape.size()) {}
|
||||
|
||||
inline bool operator()(const int64_t i, const int64_t j) const {
|
||||
for (int di = 0; di < dims_; ++di) {
|
||||
const int64_t d = order_[di];
|
||||
if (ix_(i, d) < ix_(j, d)) {
|
||||
return true;
|
||||
}
|
||||
if (ix_(i, d) > ix_(j, d)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compares two indices taken from corresponding index matrices, using the
|
||||
// standard, row-major (or lexicographic) order. Useful for cases that need
|
||||
// to distinguish between all three orderings (<, ==, >).
|
||||
inline static int cmp(const TTypes<int64_t>::ConstMatrix &a_idx, const TTypes<int64_t>::ConstMatrix &b_idx,
|
||||
const int64_t a_row, const int64_t b_row, const int dims) {
|
||||
for (int d = 0; d < dims; ++d) {
|
||||
const int64_t a = a_idx(a_row, d);
|
||||
const int64_t b = b_idx(b_row, d);
|
||||
if (a < b) {
|
||||
return -1;
|
||||
} else if (a > b) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
const TTypes<int64_t>::Matrix ix_;
|
||||
const std::vector<int64_t> order_;
|
||||
const int dims_;
|
||||
};
|
||||
template <int ORDER_DIM>
|
||||
class FixedDimComparator : DimComparator {
|
||||
public:
|
||||
FixedDimComparator(const TTypes<int64_t>::Matrix &ix, const std::vector<int64_t> &order,
|
||||
const std::vector<int64_t> &shape)
|
||||
: DimComparator(ix, order, shape) {}
|
||||
inline bool operator()(const int64_t i, const int64_t j) const {
|
||||
bool value = false;
|
||||
for (int di = 0; di < ORDER_DIM; ++di) {
|
||||
const int64_t d = order_[di];
|
||||
if (ix_(i, d) < ix_(j, d)) {
|
||||
value = true;
|
||||
break;
|
||||
}
|
||||
if (ix_(i, d) > ix_(j, d)) break;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
uint32_t Reorder(const std::vector<int64_t> &order) {
|
||||
KERNEL_CHECK_FALSE(order.size() == (std::size_t)dims_, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Order length must be SparseTensor rank");
|
||||
auto ix_t = ix_->matrix<int64_t>();
|
||||
auto vals_t = vals_->vec<T>();
|
||||
|
||||
std::vector<int64_t> reorder(ix_->GetTensor()->GetTensorShape()->GetDimSize(0));
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
|
||||
// Sort to get order of indices
|
||||
switch (order.size()) {
|
||||
#define CASE_SORT(ORDER_SIZE) \
|
||||
case ORDER_SIZE: { \
|
||||
FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape_); \
|
||||
std::sort(reorder.begin(), reorder.end(), sorter); \
|
||||
break; \
|
||||
}
|
||||
CASE_SORT(0);
|
||||
CASE_SORT(1);
|
||||
CASE_SORT(2);
|
||||
CASE_SORT(3);
|
||||
CASE_SORT(4);
|
||||
CASE_SORT(5);
|
||||
#undef CASE_SORT
|
||||
default: {
|
||||
DimComparator sorter(ix_t, order, shape_);
|
||||
std::sort(reorder.begin(), reorder.end(), sorter);
|
||||
}
|
||||
}
|
||||
|
||||
// We have a forward reordering, but what we'll need is a
|
||||
// permutation (the inverse). This can be calculated with O(1)
|
||||
// additional
|
||||
// and O(n) time (INVPERM) but we just do the simple thing here.
|
||||
std::vector<size_t> permutation(reorder.size());
|
||||
for (std::size_t n = 0; n < reorder.size(); ++n) {
|
||||
permutation[reorder[n]] = n;
|
||||
}
|
||||
|
||||
// Update indices & values by converting the permutations to
|
||||
// a product of transpositions. Iterate over the cycles in the
|
||||
// permutation, and convert each of those into a product of
|
||||
// transpositions (swaps):
|
||||
// https://en.wikipedia.org/wiki/Cyclic_permutation
|
||||
// This is N swaps, 2*N comparisons.
|
||||
for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
|
||||
while (n != permutation[n]) {
|
||||
std::size_t r = permutation[n];
|
||||
std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
|
||||
std::swap(vals_t(n), vals_t(r));
|
||||
std::swap(permutation[n], permutation[r]);
|
||||
}
|
||||
}
|
||||
|
||||
order_.assign(order.begin(), order.end());
|
||||
return 0;
|
||||
}
|
||||
template <typename T>
|
||||
static MySparseTensor *Concat(const std::vector<MySparseTensor *> &tensors, Tensor *output_ix, Tensor *output_vals) {
|
||||
const int dims = tensors[0]->dims_;
|
||||
auto order_0 = tensors[0]->order_;
|
||||
const int primary_dim = order_0[0];
|
||||
std::vector<int64_t> final_order(order_0.begin(), order_0.end());
|
||||
std::vector<int64_t> final_shape(tensors[0]->shape_.begin(), tensors[0]->shape_.end());
|
||||
final_shape[primary_dim] = 0; // We'll build this up as we go along.
|
||||
int num_entries = 0;
|
||||
|
||||
bool fully_ordered = true;
|
||||
for (const MySparseTensor *st : tensors) {
|
||||
if (st->order_ != final_order) fully_ordered = false;
|
||||
const std::vector<int64_t> &st_shape = st->shape_;
|
||||
|
||||
// Update dimension of final shape
|
||||
final_shape[primary_dim] = (final_shape[primary_dim] + st_shape[primary_dim]);
|
||||
|
||||
num_entries += st->ix_->GetTensor()->GetTensorShape()->GetDimSize(0); // Update number of entries
|
||||
}
|
||||
|
||||
// If nonconsistent ordering among inputs, set final order to -1s.
|
||||
if (!fully_ordered) {
|
||||
final_order = std::vector<int64_t>(final_shape.size(), -1);
|
||||
}
|
||||
|
||||
EigenTensor ixET(output_ix, output_ix->GetData());
|
||||
EigenTensor valsET(output_vals, output_vals->GetData());
|
||||
TTypes<int64_t>::Matrix ix_t = ixET.matrix<int64_t>();
|
||||
typename TTypes<T>::Vec vals_t = valsET.vec<T>();
|
||||
|
||||
Eigen::DenseIndex offset = 0;
|
||||
int64_t shape_offset = 0;
|
||||
for (const MySparseTensor *st : tensors) {
|
||||
const int st_num_entries = st->ix_->GetTensor()->GetTensorShape()->GetDimSize(0);
|
||||
|
||||
// Fill in indices & values.
|
||||
std::copy_n(&st->vals_->vec<T>()(0), st_num_entries, &vals_t(offset));
|
||||
|
||||
const auto *st_ix = &st->ix_->matrix<int64_t>()(0, 0);
|
||||
auto *ix_out = &ix_t(offset, 0);
|
||||
for (int i = 0; i < st_num_entries * dims; ++i) {
|
||||
*ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
|
||||
}
|
||||
|
||||
offset += st_num_entries;
|
||||
shape_offset += st->shape_[primary_dim];
|
||||
}
|
||||
MySparseTensor *res = new MySparseTensor();
|
||||
res->CreateSparseTensor(output_ix, output_vals, final_shape, final_order);
|
||||
return res;
|
||||
}
|
||||
std::vector<int64_t> shape() { return shape_; };
|
||||
|
||||
private:
|
||||
std::shared_ptr<EigenTensor> ix_;
|
||||
std::shared_ptr<EigenTensor> vals_;
|
||||
std::vector<int64_t> shape_;
|
||||
std::vector<int64_t> order_;
|
||||
int32_t dims_;
|
||||
};
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx) {
|
||||
int64_t concat_dim_attr_ = ctx.GetAttr("concat_dim") != NULL ? ctx.GetAttr("concat_dim")->GetInt() : 0;
|
||||
int64_t N = ctx.GetAttr("N") != NULL ? ctx.GetAttr("N")->GetInt() : 1;
|
||||
|
||||
vector<Tensor *> inds;
|
||||
vector<Tensor *> vals;
|
||||
vector<Tensor *> shapes;
|
||||
|
||||
vector<typename TTypes<int64_t>::Matrix> inds_t;
|
||||
vector<typename TTypes<T>::Vec> vals_t;
|
||||
vector<typename TTypes<int64_t>::Vec> shapes_t;
|
||||
for (int i = 0; i < N; i++) {
|
||||
Tensor *indice = ctx.Input(i);
|
||||
Tensor *value = ctx.Input(i + N);
|
||||
Tensor *shape = ctx.Input(i + N * 2);
|
||||
|
||||
auto indice_shape = indice->GetTensorShape();
|
||||
const int indices_dim = 2;
|
||||
KERNEL_CHECK_FALSE(indice_shape->GetDims() == indices_dim, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input indices should be a matrix but received shape %d at position %d", indice_shape->GetDims(),
|
||||
i);
|
||||
|
||||
auto value_shape = value->GetTensorShape();
|
||||
KERNEL_CHECK_FALSE(value_shape->GetDims() == 1, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input values should be a vector but received shape %d at position %d", value_shape->GetDims(),
|
||||
i);
|
||||
|
||||
auto shape_shape = shape->GetTensorShape();
|
||||
KERNEL_CHECK_FALSE(shape_shape->GetDims() == 1, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input shapes should be a vector but received shape %d at position %d", shape_shape->GetDims(),
|
||||
i);
|
||||
|
||||
int64_t ind_dim0 = indice_shape->GetDimSize(0);
|
||||
int64_t ind_dim1 = indice_shape->GetDimSize(1);
|
||||
int64_t val_dim0 = value_shape->GetDimSize(0);
|
||||
int64_t shape_dim0 = shape_shape->GetDimSize(0);
|
||||
|
||||
KERNEL_CHECK_FALSE(ind_dim0 == val_dim0, KERNEL_STATUS_PARAM_INVALID,
|
||||
"indices dim_size_0 [%lld] != values dim_size_0 [%lld] at position %d", ind_dim0, val_dim0, i);
|
||||
KERNEL_CHECK_FALSE(ind_dim1 == shape_dim0, KERNEL_STATUS_PARAM_INVALID,
|
||||
"indices dim_size_1 [%lld] != shapes dim_size_0 [%lld] at position %d", ind_dim1, shape_dim0, i);
|
||||
|
||||
EigenTensor indiceET(indice, indice->GetData());
|
||||
EigenTensor valueET(value, value->GetData());
|
||||
EigenTensor shapeET(shape, shape->GetData());
|
||||
inds_t.push_back(indiceET.matrix<int64_t>());
|
||||
vals_t.push_back(valueET.vec<T>());
|
||||
shapes_t.push_back(shapeET.vec<int64_t>());
|
||||
|
||||
inds.push_back(indice);
|
||||
vals.push_back(value);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
const typename TTypes<int64_t>::Vec input_shape = shapes_t[0];
|
||||
const int input_rank = input_shape.size();
|
||||
const int concat_dim = (concat_dim_attr_ < 0) ? input_rank + concat_dim_attr_ : concat_dim_attr_;
|
||||
KERNEL_CHECK_FALSE(concat_dim >= 0 && concat_dim < input_rank, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Concat dimension must be in range [%d,%d),got %d", -input_rank, input_rank, concat_dim_attr_);
|
||||
for (int i = 1; i < N; i++) {
|
||||
const typename TTypes<int64_t>::Vec current_shape = shapes_t[i];
|
||||
KERNEL_CHECK_FALSE(current_shape.size() == input_rank, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Ranks of all input tensors must match: expected %d,but "
|
||||
"got %d at position %d",
|
||||
input_rank, current_shape.size(), i);
|
||||
for (int j = 0; j < input_rank; j++) {
|
||||
if (j != concat_dim) {
|
||||
KERNEL_CHECK_FALSE(input_shape(j) == current_shape(j), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input shapes must match: expected %d for dimension "
|
||||
"%d but got %d at position %d",
|
||||
input_shape(j), j, current_shape(j), i);
|
||||
}
|
||||
}
|
||||
}
|
||||
vector<int64_t> std_order(input_rank);
|
||||
iota(std_order.begin(), std_order.end(), 0);
|
||||
|
||||
vector<int64_t> concat_order;
|
||||
concat_order.reserve(input_rank);
|
||||
concat_order.push_back(concat_dim);
|
||||
for (int j = 0; j < input_rank; ++j) {
|
||||
if (j != concat_dim) {
|
||||
concat_order.push_back(j);
|
||||
}
|
||||
}
|
||||
vector<MySparseTensor *> sp_inputs;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vector<int64_t> current_shape;
|
||||
for (int j = 0; j < input_rank; j++) current_shape.push_back(shapes_t[i](j));
|
||||
MySparseTensor *tensor = new MySparseTensor();
|
||||
tensor->CreateSparseTensor(inds[i], vals[i], current_shape, std_order);
|
||||
sp_inputs.push_back(std::move(tensor));
|
||||
sp_inputs[i]->Reorder<T>(concat_order);
|
||||
}
|
||||
Tensor *output_ix = ctx.Output(0);
|
||||
Tensor *output_vals = ctx.Output(1);
|
||||
|
||||
MySparseTensor *concat = MySparseTensor::Concat<T>(sp_inputs, output_ix, output_vals);
|
||||
concat->Reorder<T>(std_order);
|
||||
|
||||
Tensor *output_shape_out = ctx.Output(2);
|
||||
EigenTensor output_shapeET(output_shape_out, output_shape_out->GetData());
|
||||
auto output_shape = output_shapeET.vec<int64_t>();
|
||||
auto concat_shape = concat->shape();
|
||||
for (std::size_t i = 0; i < concat_shape.size(); i++) {
|
||||
output_shape(i) = concat_shape[i];
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t SparseConcatCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
int64_t N = ctx.GetAttr("N") != NULL ? ctx.GetAttr("N")->GetInt() : 1;
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, N * kInputNum, kOutputNum),
|
||||
"SparseConcat check input and output number failed.");
|
||||
auto data_type = ctx.Input(N)->GetDataType();
|
||||
switch (data_type) {
|
||||
case DT_INT8:
|
||||
return DoCompute<int8_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DoCompute<uint8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DoCompute<int16_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DoCompute<uint16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DoCompute<int32_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DoCompute<uint32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DoCompute<int64_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DoCompute<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DoCompute<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DoCompute<float>(ctx);
|
||||
case DT_BOOL:
|
||||
return DoCompute<bool>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DoCompute<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DoCompute<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DoCompute<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("SparseConcat kernel data type [%u] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kSparseConcat, SparseConcatCpuKernel);
|
||||
} // namespace aicpu
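The kernel above implements COO concatenation: DoCompute validates the N (indices, values, shape) triples, Reorder sorts each input so that concat_dim is the leading sort key, Concat copies values and shifts the concat_dim coordinate of every nonzero by a running shape offset, and a final Reorder restores row-major order. A self-contained sketch of the same shift-and-append idea in plain C++ (illustrative only; it uses std containers instead of the aicpu Tensor/EigenTensor API, and names such as Coo and ConcatCoo are hypothetical):

#include <cstdint>
#include <iostream>
#include <vector>

struct Coo {                                   // COO sparse tensor, row-major sorted
  std::vector<std::vector<int64_t>> indices;
  std::vector<float> values;
  std::vector<int64_t> shape;
};

// Concatenate along `dim` by shifting that coordinate of each later input.
Coo ConcatCoo(const std::vector<Coo> &inputs, int dim) {
  Coo out;
  out.shape = inputs[0].shape;
  out.shape[dim] = 0;
  int64_t offset = 0;
  for (const Coo &in : inputs) {
    for (size_t n = 0; n < in.values.size(); ++n) {
      std::vector<int64_t> idx = in.indices[n];
      idx[dim] += offset;                       // same shift as shape_offset in Concat above
      out.indices.push_back(idx);
      out.values.push_back(in.values[n]);
    }
    offset += in.shape[dim];
    out.shape[dim] += in.shape[dim];
  }
  return out;                                   // still needs a row-major re-sort when dim != 0
}

int main() {
  Coo a{{{0, 1}, {1, 0}}, {1.f, 2.f}, {2, 3}};  // 2 x 3 with entries (0,1)=1, (1,0)=2
  Coo b{{{0, 2}}, {3.f}, {2, 3}};               // 2 x 3 with entry (0,2)=3
  Coo c = ConcatCoo({a, b}, 0);                 // shape becomes 4 x 3, b's entry moves to (2,2)
  for (size_t n = 0; n < c.values.size(); ++n) {
    std::cout << "(" << c.indices[n][0] << "," << c.indices[n][1] << ") = " << c.values[n] << "\n";
  }
}

Running it prints (0,1) = 1, (1,0) = 2, (2,2) = 3 with the combined shape 4 x 3, which matches what the Concat/Reorder pair above produces for the same inputs before the dense output_shape is written.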
sparse_concat.h (new file)
@@ -0,0 +1,29 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_KERNELS_NORMALIZED_SPARSE_CONCAT_H_
#define AICPU_KERNELS_NORMALIZED_SPARSE_CONCAT_H_

#include "cpu_ops_kernel.h"

namespace aicpu {
class SparseConcatCpuKernel : public CpuKernel {
 public:
  ~SparseConcatCpuKernel() = default;
  uint32_t Compute(CpuKernelContext &ctx) override;
};
} // namespace aicpu
#endif
sparse_dense_cwise_add.cc (new file)
@@ -0,0 +1,48 @@
#include "sparse_dense_cwise_add.h"
#include <iostream>
#include "utils/kernel_util.h"
#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
namespace {
const char *kSparseDenseCwiseAdd = "SparseDenseCwiseAdd";

#define SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DTYPE, TYPE, CTX) \
  case (DTYPE): { \
    uint32_t result = SparseDenseCwiseOpCompute<TYPE>(CTX); \
    if (result != KERNEL_STATUS_OK) { \
      KERNEL_LOG_ERROR("SparseDenseCwiseAdd kernel compute failed."); \
      return result; \
    } \
    break; \
  }
} // namespace

uint32_t SparseDenseCwiseAddKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseAdd check params failed.");

  auto data_type = ctx.Input(1)->GetDataType();
  switch (data_type) {
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT8, int8_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT16, int16_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT32, int32_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_INT64, int64_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_DOUBLE, double, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_FLOAT, float, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
    SPARSE_DENSE_CWISE_ADD_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
    default:
      KERNEL_LOG_ERROR("SparseDenseCwiseAdd kernel data type %s not support.", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }

  return KERNEL_STATUS_OK;
}

REGISTER_CPU_KERNEL(kSparseDenseCwiseAdd, SparseDenseCwiseAddKernel);
} // namespace aicpu
sparse_dense_cwise_add.h (new file)
@@ -0,0 +1,17 @@
#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_ADD_H_
#define AICPU_KERNELS_SPARSE_DENSE_CWISE_ADD_H_

#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
class SparseDenseCwiseAddKernel : public SparseDenseCwiseOpKernel<AddOp> {
 public:
  SparseDenseCwiseAddKernel() = default;
  ~SparseDenseCwiseAddKernel() override = default;

 protected:
  uint32_t Compute(CpuKernelContext &ctx) override;
};

} // namespace aicpu
#endif
sparse_dense_cwise_div.cc (new file)
@@ -0,0 +1,48 @@
#include "sparse_dense_cwise_div.h"
#include <iostream>
#include "utils/kernel_util.h"
#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
namespace {
const char *kSparseDenseCwiseDiv = "SparseDenseCwiseDiv";

#define SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DTYPE, TYPE, CTX) \
  case (DTYPE): { \
    uint32_t result = SparseDenseCwiseOpCompute<TYPE>(CTX); \
    if (result != KERNEL_STATUS_OK) { \
      KERNEL_LOG_ERROR("SparseDenseCwiseDiv kernel compute failed."); \
      return result; \
    } \
    break; \
  }
} // namespace

uint32_t SparseDenseCwiseDivKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseDiv check params failed.");

  auto data_type = ctx.Input(1)->GetDataType();
  switch (data_type) {
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT8, int8_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT16, int16_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT32, int32_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_INT64, int64_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_DOUBLE, double, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_FLOAT, float, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
    SPARSE_DENSE_CWISE_DIV_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
    default:
      KERNEL_LOG_ERROR("SparseDenseCwiseDiv kernel data type %s not support.", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }

  return KERNEL_STATUS_OK;
}

REGISTER_CPU_KERNEL(kSparseDenseCwiseDiv, SparseDenseCwiseDivKernel);
} // namespace aicpu
sparse_dense_cwise_div.h (new file)
@@ -0,0 +1,17 @@
#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_DIV_H_
#define AICPU_KERNELS_SPARSE_DENSE_CWISE_DIV_H_

#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
class SparseDenseCwiseDivKernel : public SparseDenseCwiseOpKernel<DivOp> {
 public:
  SparseDenseCwiseDivKernel() = default;
  ~SparseDenseCwiseDivKernel() override = default;

 protected:
  uint32_t Compute(CpuKernelContext &ctx) override;
};

} // namespace aicpu
#endif
sparse_dense_cwise_mul.cc (new file)
@@ -0,0 +1,48 @@
#include "sparse_dense_cwise_mul.h"
#include <iostream>
#include "utils/kernel_util.h"
#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
namespace {
const char *kSparseDenseCwiseMul = "SparseDenseCwiseMul";

#define SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DTYPE, TYPE, CTX) \
  case (DTYPE): { \
    uint32_t result = SparseDenseCwiseOpCompute<TYPE>(CTX); \
    if (result != KERNEL_STATUS_OK) { \
      KERNEL_LOG_ERROR("SparseDenseCwiseMul kernel compute failed."); \
      return result; \
    } \
    break; \
  }
} // namespace

uint32_t SparseDenseCwiseMulKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(CheckParams(ctx), "SparseDenseCwiseMul check params failed.");

  auto data_type = ctx.Input(1)->GetDataType();
  switch (data_type) {
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT8, int8_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT16, int16_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT32, int32_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_INT64, int64_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_DOUBLE, double, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_FLOAT, float, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
    SPARSE_DENSE_CWISE_MUL_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
    default:
      KERNEL_LOG_ERROR("SparseDenseCwiseMul kernel data type %s not support.", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }

  return KERNEL_STATUS_OK;
}

REGISTER_CPU_KERNEL(kSparseDenseCwiseMul, SparseDenseCwiseMulKernel);
} // namespace aicpu
sparse_dense_cwise_mul.h (new file)
@@ -0,0 +1,17 @@
#ifndef AICPU_KERNELS_SPARSE_DENSE_CWISE_MUL_H_
#define AICPU_KERNELS_SPARSE_DENSE_CWISE_MUL_H_

#include "utils/sparse_dense_cwise_utils.h"

namespace aicpu {
class SparseDenseCwiseMulKernel : public SparseDenseCwiseOpKernel<MulOp> {
 public:
  SparseDenseCwiseMulKernel() = default;
  ~SparseDenseCwiseMulKernel() override = default;

 protected:
  uint32_t Compute(CpuKernelContext &ctx) override;
};

} // namespace aicpu
#endif
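The utils file that follows does the real work for SparseDenseCwiseAdd/Div/Mul: every nonzero's COO index is flattened to a row-major offset into the (possibly broadcast) dense operand, and the scalar op is applied to the matching pair. The nested "c = c * sparse_shape_data[k]" loops below compute exactly this offset; a compact sketch of the formula (plain C++, illustrative only, FlattenIndex is a hypothetical helper name):

#include <cstdint>
#include <vector>

// Row-major offset used by the SparseDenseCwise kernels below:
//   offset(idx) = sum_j idx[j] * prod_{k > j} shape[k]
int64_t FlattenIndex(const std::vector<int64_t> &idx, const std::vector<int64_t> &shape) {
  int64_t offset = 0;
  int64_t stride = 1;
  for (int64_t j = static_cast<int64_t>(shape.size()) - 1; j >= 0; --j) {
    offset += idx[j] * stride;  // innermost dimension varies fastest
    stride *= shape[j];
  }
  return offset;
}

For idx = (1, 2) in a 3 x 4 tensor the offset is 1 * 4 + 2 = 6, which is the element the kernels read from the dense input before adding, dividing, or multiplying it with the sparse value.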
sparse_dense_cwise_utils.cc (new file)
@@ -0,0 +1,754 @@
#include "sparse_dense_cwise_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "broadcast_iterator.h"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "kernel_util.h"
|
||||
#include "kernel_log.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "utils/sparse_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace {
|
||||
const uint32_t kInputNum_SparseDenseCwiseOp = 4;
|
||||
const uint32_t kOutputNum_SparseDenseCwiseOp = 1;
|
||||
const int64_t kParallelDataNum = 2 * 1024;
|
||||
const int64_t kParallelDataNumSameShape = 7 * 1024;
|
||||
} // namespace
|
||||
|
||||
template <typename Op>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::CheckParams(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum_SparseDenseCwiseOp, kOutputNum_SparseDenseCwiseOp),
|
||||
"SparseDenseCwise%s normal check failed.", Op::Name().c_str());
|
||||
|
||||
Tensor *x1_indices = ctx.Input(0);
|
||||
Tensor *x1_values = ctx.Input(1);
|
||||
Tensor *x1_shape = ctx.Input(2);
|
||||
Tensor *x2 = ctx.Input(3);
|
||||
Tensor *y = ctx.Output(0);
|
||||
|
||||
DataType x1_indices_type = x1_indices->GetDataType();
|
||||
DataType x1_values_type = x1_values->GetDataType();
|
||||
DataType x1_shape_type = x1_shape->GetDataType();
|
||||
DataType x2_type = x2->GetDataType();
|
||||
DataType y_type = y->GetDataType();
|
||||
KERNEL_CHECK_FALSE((x1_indices_type == x1_shape_type), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of x1_indices_type [%s] need be same with "
|
||||
"x1_shape [%s].",
|
||||
DTypeStr(x1_indices_type).c_str(), DTypeStr(x1_shape_type).c_str())
|
||||
KERNEL_CHECK_FALSE(((x1_values_type == x2_type) && (x1_values_type == y_type)), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of x1_values_type [%s] need be same with "
|
||||
"x2_type[%s] and y_type [%s].",
|
||||
DTypeStr(x1_values_type).c_str(), DTypeStr(x2_type).c_str(), DTypeStr(y_type).c_str())
|
||||
KERNEL_CHECK_FALSE((x1_indices_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of x1_indices_type [%s] need be int_64.", DTypeStr(x1_indices_type).c_str())
|
||||
int32_t input0_dims = x1_indices->GetTensorShape()->GetDims();
|
||||
int32_t input1_dims = x1_values->GetTensorShape()->GetDims();
|
||||
int32_t input2_dims = x1_shape->GetTensorShape()->GetDims();
|
||||
int32_t input3_dims = x2->GetTensorShape()->GetDims();
|
||||
int32_t output_dims = y->GetTensorShape()->GetDims();
|
||||
int64_t shape_elements_nums = x1_shape->GetTensorShape()->NumElements();
|
||||
int64_t indices_0 = x1_indices->GetTensorShape()->GetDimSize(0);
|
||||
int64_t value_0 = x1_values->GetTensorShape()->GetDimSize(0);
|
||||
KERNEL_CHECK_FALSE((int(input0_dims) == 2), KERNEL_STATUS_PARAM_INVALID, "The dims of input0 need be 2.")
|
||||
KERNEL_CHECK_FALSE((input1_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of input1 need be 1 .")
|
||||
KERNEL_CHECK_FALSE((input2_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of input2 need be 1.")
|
||||
KERNEL_CHECK_FALSE((output_dims == 1), KERNEL_STATUS_PARAM_INVALID, "The dims of output need be 1.")
|
||||
KERNEL_CHECK_FALSE((input3_dims <= shape_elements_nums), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The dims of DenseTensor is large than sparseTensor.")
|
||||
KERNEL_CHECK_FALSE((indices_0 == value_0), KERNEL_STATUS_PARAM_INVALID, "The num of indices is not equal to value.")
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialCompute(BcastShapeType type, CpuKernelContext &ctx) {
|
||||
auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
|
||||
auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
||||
int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
|
||||
int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
|
||||
int64_t data_num = ctx.Input(1)->NumElements();
|
||||
|
||||
std::vector<T> sparse_values_vec(data_num);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
sparse_values_vec[i] = (sparse_values_data[i]);
|
||||
}
|
||||
int64_t dims = ctx.Input(2)->NumElements();
|
||||
int64_t Sparse_numelements = 1;
|
||||
for (int64_t i = 0; i < dims; i++) {
|
||||
Sparse_numelements *= sparse_shape_data[i];
|
||||
}
|
||||
if (Sparse_numelements >= kParallelDataNumSameShape) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0");
|
||||
}
|
||||
if (max_core_num > value_nums) {
|
||||
max_core_num = value_nums;
|
||||
}
|
||||
|
||||
auto sharder_Op = [&](int64_t start, int64_t end) {
|
||||
switch (type) {
|
||||
case BcastShapeType::SAME_SHAPE:
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += c * sparse_indices_data[j + i * dimension];
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + dense_data[index];
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(dense_data[index])) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / dense_data[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * dense_data[index];
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case BcastShapeType::Y_ONE_ELEMENT:
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_data[i] + *(dense_data);
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(*(dense_data))) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_data[i] / *(dense_data);
|
||||
}
|
||||
} else {
|
||||
output_data[i] = *(dense_data)*sparse_values_data[i];
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
|
||||
break;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
|
||||
"Op Compute failed.");
|
||||
} else {
|
||||
switch (type) {
|
||||
case BcastShapeType::SAME_SHAPE:
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += c * sparse_indices_data[j + i * dimension];
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + dense_data[index];
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(dense_data[index])) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / dense_data[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * dense_data[index];
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case BcastShapeType::Y_ONE_ELEMENT:
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_data[i] + *(dense_data);
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(*(dense_data))) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_data[i] / *(dense_data);
|
||||
}
|
||||
} else {
|
||||
output_data[i] = *(dense_data)*sparse_values_data[i];
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpNoBcastCompute(CpuKernelContext &ctx) {
|
||||
auto *input2_tensor = ctx.Input(2);
|
||||
auto *input3_tensor = ctx.Input(3);
|
||||
int64_t dimension = input2_tensor->NumElements();
|
||||
int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
|
||||
BcastShapeType type = dimension == dense_dims ? BcastShapeType::SAME_SHAPE : BcastShapeType::Y_ONE_ELEMENT;
|
||||
SparseDenseCwiseOpSpecialCompute<T>(type, ctx);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastCompute(CpuKernelContext &ctx) {
|
||||
auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
|
||||
auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
||||
int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
|
||||
int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
|
||||
auto dense_shape = ctx.Input(3)->GetTensorShape()->GetDimSizes();
|
||||
int64_t dims = ctx.Input(2)->NumElements();
|
||||
int64_t data_num = ctx.Input(1)->NumElements();
|
||||
|
||||
int64_t Sparse_numelements = 1;
|
||||
for (int64_t i = 0; i < dims; i++) {
|
||||
Sparse_numelements *= sparse_shape_data[i];
|
||||
}
|
||||
|
||||
std::vector<T> sparse_values_vec(data_num);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
sparse_values_vec[i] = (sparse_values_data[i]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> sparse_shape(dimension);
|
||||
for (int64_t i = 0; i < dimension; i++) {
|
||||
sparse_shape[i] = sparse_shape_data[i];
|
||||
}
|
||||
std::vector<int64_t> sparse_shape1(dimension);
|
||||
for (int64_t j = 0; j < dimension; j++) {
|
||||
sparse_shape1[j] = sparse_shape[j];
|
||||
}
|
||||
|
||||
BroadcastIterator broad_base_iter_1(sparse_shape, dense_shape, sparse_shape1);
|
||||
std::vector<T> Dense(Sparse_numelements);
|
||||
broad_base_iter_1.SetPos(0);
|
||||
for (int64_t i = 0; i < Sparse_numelements; i++) {
|
||||
Dense[i] = dense_data[broad_base_iter_1.GetInputPosB()];
|
||||
broad_base_iter_1.GenNextPos();
|
||||
}
|
||||
if (Sparse_numelements >= kParallelDataNum) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0");
|
||||
}
|
||||
if (max_core_num > value_nums) {
|
||||
max_core_num = value_nums;
|
||||
}
|
||||
auto sharder_Op = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += sparse_indices_data[j + i * dimension] * c;
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + Dense[index];
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(Dense[index])) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / Dense[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * Dense[index];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
|
||||
"Op Compute failed.");
|
||||
} else {
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += sparse_indices_data[j + i * dimension] * c;
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + Dense[index];
|
||||
} else if (name == "Div") {
|
||||
if (fabs(double(Dense[index])) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / Dense[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * Dense[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpCompute(CpuKernelContext &ctx) {
|
||||
auto data_type = ctx.Input(1)->GetDataType();
|
||||
switch (data_type) {
|
||||
case DT_INT8:
|
||||
return ComputeOp<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return ComputeOp<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return ComputeOp<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return ComputeOp<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return ComputeOp<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return ComputeOp<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return ComputeOp<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return ComputeOp<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return ComputeOp<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return ComputeOp<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return ComputeOp<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return ComputeOpComplex<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return ComputeOpComplex<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("sparse_dense_cwise kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::ComputeOp(CpuKernelContext &ctx) {
|
||||
auto *input3_tensor = ctx.Input(3);
|
||||
auto dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
|
||||
int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
|
||||
auto dense_shape = input3_tensor->GetTensorShape()->GetDimSizes();
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
int64_t dense_num = ctx.Input(3)->GetTensorShape()->NumElements();
|
||||
|
||||
std::vector<int64_t> sparse_shape(dimension);
|
||||
for (int64_t i = 0; i < dimension; i++) {
|
||||
sparse_shape[i] = sparse_shape_data[i];
|
||||
}
|
||||
|
||||
bool isNeedBcast = (dense_shape == sparse_shape || dense_num == 1);
|
||||
if (isNeedBcast) {
|
||||
return SparseDenseCwiseOpNoBcastCompute<T>(ctx);
|
||||
} else {
|
||||
if (dense_dims <= dimension) {
|
||||
for (int i = dense_dims - 1; i >= 0; --i) {
|
||||
if ((dense_shape[i] != 1) && (dense_shape[i] != sparse_shape[i + dimension - dense_dims])) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return SparseDenseCwiseOpBcastCompute<T>(ctx);
|
||||
} else {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::ComputeOpComplex(CpuKernelContext &ctx) {
|
||||
auto *input2_tensor = ctx.Input(2);
|
||||
auto *input3_tensor = ctx.Input(3);
|
||||
int64_t dense_num = ctx.Input(3)->GetTensorShape()->NumElements();
|
||||
int64_t dimension = input2_tensor->NumElements();
|
||||
int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
|
||||
auto dense_shape = input3_tensor->GetTensorShape()->GetDimSizes();
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
|
||||
std::vector<int64_t> sparse_shape(dimension);
|
||||
for (int64_t i = 0; i < dimension; i++) {
|
||||
sparse_shape[i] = sparse_shape_data[i];
|
||||
}
|
||||
|
||||
bool isNeedBcast = (dense_shape == sparse_shape || dense_num == 1);
|
||||
if (isNeedBcast) {
|
||||
return SparseDenseCwiseOpNoBcastComputeComplex<T>(ctx);
|
||||
} else {
|
||||
if (dense_dims <= dimension) {
|
||||
for (int i = dense_dims - 1; i >= 0; --i) {
|
||||
if ((dense_shape[i] != 1) && (dense_shape[i] != sparse_shape[i + dimension - dense_dims])) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return SparseDenseCwiseOpBcastComputeComplex<T>(ctx);
|
||||
} else {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialComputeComplex(BcastShapeType type,
|
||||
CpuKernelContext &ctx) {
|
||||
auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
|
||||
auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
||||
int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
|
||||
int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
|
||||
int64_t data_num = ctx.Input(1)->NumElements();
|
||||
std::vector<T> sparse_values_vec(data_num);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
sparse_values_vec[i] = (sparse_values_data[i]);
|
||||
}
|
||||
int64_t dims = ctx.Input(2)->NumElements();
|
||||
int64_t Sparse_numelements = 1;
|
||||
for (int64_t i = 0; i < dims; i++) {
|
||||
Sparse_numelements *= sparse_shape_data[i];
|
||||
}
|
||||
if (Sparse_numelements >= kParallelDataNumSameShape) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0");
|
||||
}
|
||||
if (max_core_num > value_nums) {
|
||||
max_core_num = value_nums;
|
||||
}
|
||||
|
||||
auto sharder_Op = [&](int64_t start, int64_t end) {
|
||||
switch (type) {
|
||||
case BcastShapeType::SAME_SHAPE:
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += c * sparse_indices_data[j + i * dimension];
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + dense_data[index];
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(dense_data[index]) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / dense_data[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * dense_data[index];
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case BcastShapeType::Y_ONE_ELEMENT:
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_data[i] + *(dense_data);
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(*(dense_data)) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_data[i] / *(dense_data);
|
||||
}
|
||||
} else {
|
||||
output_data[i] = *(dense_data)*sparse_values_data[i];
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
|
||||
break;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
|
||||
"Op Compute failed.");
|
||||
} else {
|
||||
switch (type) {
|
||||
case BcastShapeType::SAME_SHAPE:
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += c * sparse_indices_data[j + i * dimension];
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + dense_data[index];
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(dense_data[index]) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / dense_data[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * dense_data[index];
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case BcastShapeType::Y_ONE_ELEMENT:
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_data[i] + *(dense_data);
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(*(dense_data)) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_data[i] / *(dense_data);
|
||||
}
|
||||
} else {
|
||||
output_data[i] = *(dense_data)*sparse_values_data[i];
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpNoBcastComputeComplex(CpuKernelContext &ctx) {
|
||||
auto *input2_tensor = ctx.Input(2);
|
||||
auto *input3_tensor = ctx.Input(3);
|
||||
int64_t dimension = input2_tensor->NumElements();
|
||||
int32_t dense_dims = input3_tensor->GetTensorShape()->GetDims();
|
||||
|
||||
BcastShapeType type = dimension == dense_dims ? BcastShapeType::SAME_SHAPE : BcastShapeType::Y_ONE_ELEMENT;
|
||||
SparseDenseCwiseOpSpecialComputeComplex<T>(type, ctx);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
template <typename T>
|
||||
uint32_t SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastComputeComplex(CpuKernelContext &ctx) {
|
||||
auto sparse_indices_data = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
|
||||
auto sparse_values_data = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto sparse_shape_data = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
auto dense_data = reinterpret_cast<T *>(ctx.Input(3)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
||||
int64_t value_nums = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
|
||||
int64_t dimension = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
|
||||
auto dense_shape = ctx.Input(3)->GetTensorShape()->GetDimSizes();
|
||||
int64_t dims = ctx.Input(2)->NumElements();
|
||||
int64_t data_num = ctx.Input(1)->NumElements();
|
||||
|
||||
int64_t Sparse_numelements = 1;
|
||||
for (int64_t i = 0; i < dims; i++) {
|
||||
Sparse_numelements *= sparse_shape_data[i];
|
||||
}
|
||||
|
||||
std::vector<T> sparse_values_vec(data_num);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
sparse_values_vec[i] = (sparse_values_data[i]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> sparse_shape(dimension);
|
||||
for (int64_t i = 0; i < dimension; i++) {
|
||||
sparse_shape[i] = sparse_shape_data[i];
|
||||
}
|
||||
std::vector<int64_t> sparse_shape1(dimension);
|
||||
for (int64_t j = 0; j < dimension; j++) {
|
||||
sparse_shape1[j] = sparse_shape[j];
|
||||
}
|
||||
|
||||
BroadcastIterator broad_base_iter_1(sparse_shape, dense_shape, sparse_shape1);
|
||||
std::vector<T> Dense(Sparse_numelements);
|
||||
broad_base_iter_1.SetPos(0);
|
||||
for (int64_t i = 0; i < Sparse_numelements; i++) {
|
||||
Dense[i] = dense_data[broad_base_iter_1.GetInputPosB()];
|
||||
broad_base_iter_1.GenNextPos();
|
||||
}
|
||||
|
||||
if (Sparse_numelements >= kParallelDataNum) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0");
|
||||
}
|
||||
if (max_core_num > value_nums) {
|
||||
max_core_num = value_nums;
|
||||
}
|
||||
|
||||
auto sharder_Op = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += sparse_indices_data[j + i * dimension] * c;
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + Dense[index];
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(Dense[index]) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / Dense[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * Dense[index];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, value_nums, value_nums / max_core_num, sharder_Op),
|
||||
"Op Compute failed.");
|
||||
} else {
|
||||
for (int64_t i = 0; i < value_nums; i++) {
|
||||
int index = 0;
|
||||
for (int64_t j = 0; j < dimension - 1; j++) {
|
||||
int c = 1;
|
||||
for (int k = j + 1; k < dimension; k++) {
|
||||
c = c * sparse_shape_data[k];
|
||||
}
|
||||
index += sparse_indices_data[j + i * dimension] * c;
|
||||
}
|
||||
index += sparse_indices_data[(i + 1) * dimension - 1];
|
||||
std::string name = Op::Name().c_str();
|
||||
if (name == "Add") {
|
||||
output_data[i] = sparse_values_vec[i] + Dense[index];
|
||||
} else if (name == "Div") {
|
||||
if (std::abs(Dense[index]) < 1e-6) {
|
||||
KERNEL_LOG_ERROR("Cannot be divided by 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] / Dense[index];
|
||||
}
|
||||
} else {
|
||||
output_data[i] = sparse_values_vec[i] * Dense[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template class SparseDenseCwiseOpKernel<AddOp>;
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<int8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<int16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<int32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<int64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<uint8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<uint16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<uint32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<uint64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<Eigen::half>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<float>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<double>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<std::complex<float>>(
|
||||
CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<AddOp>::SparseDenseCwiseOpCompute<std::complex<double>>(
|
||||
CpuKernelContext &ctx);
|
||||
|
||||
template class SparseDenseCwiseOpKernel<DivOp>;
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<int8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<int16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<int32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<int64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<uint8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<uint16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<uint32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<uint64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<Eigen::half>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<float>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<double>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<std::complex<float>>(
|
||||
CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<DivOp>::SparseDenseCwiseOpCompute<std::complex<double>>(
|
||||
CpuKernelContext &ctx);
|
||||
|
||||
template class SparseDenseCwiseOpKernel<MulOp>;
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<int64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint8_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint16_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint32_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<uint64_t>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<Eigen::half>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<float>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<double>(CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<std::complex<float>>(
|
||||
CpuKernelContext &ctx);
|
||||
template uint32_t SparseDenseCwiseOpKernel<MulOp>::SparseDenseCwiseOpCompute<std::complex<double>>(
|
||||
CpuKernelContext &ctx);
|
||||
} // namespace aicpu