!36418 Modify input of sparse add

Merge pull request !36418 from YijieChen/sparse_add_dev
This commit is contained in:
i-robot 2022-06-28 03:30:40 +00:00 committed by Gitee
commit 3681ab0ecb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 84 additions and 101 deletions

View File

@ -15,45 +15,48 @@
*/
#include "plugin/device/cpu/kernel/sparse_add_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <set>
#include <map>
#include <set>
#include <utility>
#include "include/common/thread_pool.h"
#include "mindspore/core/ops/sparse_add.h"
namespace mindspore {
namespace kernel {
// Value check constant
constexpr size_t kInputNum = 4;
constexpr size_t kOutputNum = 2;
constexpr size_t kInputNum = 7;
constexpr size_t kOutputNum = 3;
constexpr size_t kNumOfColumn = 2;
// Input idx constant
constexpr size_t kAIndicesIdx = 0;
constexpr size_t kAValuesIdx = 1;
constexpr size_t kBIndicesIdx = 2;
constexpr size_t kBValuesIdx = 3;
constexpr size_t kAShapeIdx = 2;
constexpr size_t kBIndicesIdx = 3;
constexpr size_t kBValuesIdx = 4;
constexpr size_t kBShapeIdx = 5;
constexpr size_t kThreshIdx = 6;
// Output idx constant
constexpr size_t kSumIndicesIdx = 0;
constexpr size_t kSumValuesIdx = 1;
constexpr size_t kSumShapeIdx = 2;
bool SparseAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
outputs_ = outputs;
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseAdd>(base_operator);
thresh_ = kernel_ptr->get_thresh();
kernel_name_ = kernel_ptr->name();
size_t input_num = inputs.size();
if (input_num != kInputNum) {
MS_LOG(ERROR) << "For " << kernel_name_ << ", input should be a_indices, a_values, b_indices and b_values total "
MS_LOG(ERROR) << "For " << kernel_name_
<< ", input should be a_indices, a_values, a_shape, b_indices, b_values, b_shape and thresh total "
<< kInputNum << " tensors, but get " << input_num;
return false;
}
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
auto dense_shape = kernel_ptr->get_a_dense_shape();
row_ = LongToSize(dense_shape[0]);
dense_size_ = row_ * LongToSize(dense_shape[1]) * GetTypeByte(TypeIdToType(types_[1]));
is_need_retrieve_output_shape_ = true;
for (size_t i = 0; i < kOutputNum; i++) {
auto dtype = inputs[i]->GetDtype();
@ -71,11 +74,10 @@ int SparseAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
MS_LOG(ERROR) << "Input size list should be " << kInputNum << ", but got " << input_size_list_.size();
return KRET_RESIZE_FAILED;
}
auto max_indices_out_size =
std::min(input_size_list_[kAIndicesIdx] + input_size_list_[kBIndicesIdx], dense_size_ * 2);
auto max_value_out_size = std::min(input_size_list_[kAValuesIdx] + input_size_list_[kBValuesIdx], dense_size_);
output_size_list_.emplace_back(max_indices_out_size);
output_size_list_.emplace_back(max_value_out_size);
auto max_indices_out_size = input_size_list_[kAIndicesIdx] + input_size_list_[kBIndicesIdx];
auto max_value_out_size = input_size_list_[kAValuesIdx] + input_size_list_[kBValuesIdx];
output_size_list_[kSumIndicesIdx] = max_indices_out_size;
output_size_list_[kSumValuesIdx] = max_value_out_size;
}
return ret;
}
@ -109,11 +111,14 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
// Inputs
const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->addr);
const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->addr);
const auto a_shape = reinterpret_cast<int *>(inputs[kAShapeIdx]->addr);
const auto b_indices = reinterpret_cast<T *>(inputs[kBIndicesIdx]->addr);
const auto b_values = reinterpret_cast<S *>(inputs[kBValuesIdx]->addr);
const auto thresh = reinterpret_cast<float *>(inputs[kThreshIdx]->addr);
// Outputs
auto sum_indices = reinterpret_cast<T *>(outputs[kSumIndicesIdx]->addr);
auto sum_values = reinterpret_cast<S *>(outputs[kSumValuesIdx]->addr);
auto sum_shape = reinterpret_cast<int *>(outputs[kSumShapeIdx]->addr);
const int64_t a_indices_num = inputs[kAIndicesIdx]->size / ((sizeof(T)) * 2);
const int64_t b_indices_num = inputs[kBIndicesIdx]->size / ((sizeof(T)) * 2);
@ -133,7 +138,7 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
break;
case 0:
sum_ab = a_values[i] + b_values[j];
if (thresh_ <= std::abs(sum_ab)) {
if ((*thresh) <= std::abs(sum_ab)) {
whole_indices.emplace_back(true, i);
whole_values.push_back(sum_ab);
}
@ -177,6 +182,10 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
sum_values[num] = whole_values[num];
}
for (size_t num_out = 0; num_out < kNumOfColumn; num_out++) {
sum_shape[num_out] = a_shape[num_out];
}
// Update output shape and type
std::vector<int64_t> out_indices_shape;
std::vector<int64_t> out_values_shape;
@ -189,31 +198,40 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
return true;
}
#define CPU_SPARSE_ADD_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
{ \
KernelAttr() \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_value_type) \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_value_type) \
.AddOutputAttr(ms_index_type) \
.AddOutputAttr(ms_value_type), \
&SparseAddCpuKernelMod::LaunchKernel<index_type, value_type> \
#define CPU_SPARSE_ADD_KERNEL_REGISTER(ms_index_type, ms_value_type, ms_shape_type, ms_thresh_type, index_type, \
value_type) \
{ \
KernelAttr() \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_value_type) \
.AddInputAttr(ms_shape_type) \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_value_type) \
.AddInputAttr(ms_shape_type) \
.AddInputAttr(ms_thresh_type) \
.AddOutputAttr(ms_index_type) \
.AddOutputAttr(ms_value_type) \
.AddOutputAttr(ms_shape_type), \
&SparseAddCpuKernelMod::LaunchKernel<index_type, value_type> \
}
const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> &SparseAddCpuKernelMod::GetFuncList()
const {
static const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> func_list = {
// float values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, int, float),
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, int,
float),
// double values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, int, double),
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat32, int,
double),
// int values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int, int),
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, kNumberTypeFloat32, int, int),
// int64 values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int, int64_t),
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, kNumberTypeFloat32, int,
int64_t),
// int16 values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, int, int16_t),
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeFloat32, int,
int16_t),
};
return func_list;
}

View File

@ -30,10 +30,6 @@ namespace ops {
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
namespace {
constexpr auto kADenseShape = "a_shape";
constexpr auto kBDenseShape = "b_shape";
constexpr auto kThresh = "thresh";
inline void CheckSparseAddShape(const size_t sparse_shape_size, const size_t expected_dim,
const std::string &arg_name) {
if (sparse_shape_size != expected_dim) {
@ -45,74 +41,43 @@ inline void CheckSparseAddShape(const size_t sparse_shape_size, const size_t exp
inline void CheckSparseAddIndicesDtype(const mindspore::TypePtr dtype, const std::string &arg_name) {
if (!(dtype->equal(mindspore::kInt32))) {
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
<< dtype->ToString() << ".";
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " Int32 but got " << dtype->ToString() << ".";
}
}
} // namespace
void SparseAdd::set_a_dense_shape(const std::vector<int64_t> &shape) {
(void)this->AddAttr(kADenseShape, api::MakeValue(shape));
}
void SparseAdd::set_b_dense_shape(const std::vector<int64_t> &shape) {
(void)this->AddAttr(kBDenseShape, api::MakeValue(shape));
}
void SparseAdd::set_thresh(const float &thresh) { (void)this->AddAttr(kThresh, api::MakeValue(thresh)); }
std::vector<int64_t> SparseAdd::get_a_dense_shape() const {
auto value_ptr = GetAttr(kADenseShape);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> SparseAdd::get_b_dense_shape() const {
auto value_ptr = GetAttr(kBDenseShape);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
float SparseAdd::get_thresh() const {
auto value_ptr = GetAttr(kThresh);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
void SparseAdd::Init(const std::vector<int64_t> &a_shape, const std::vector<int64_t> &b_shape, const float &thresh) {
auto op_name = this->name();
if (a_shape.size() != b_shape.size()) {
MS_LOG(EXCEPTION) << "For " << op_name
<< " the rank of two dense shape should be the same, but got the rank of a_shape is "
<< a_shape.size() << ", and b_shape is " << b_shape.size();
}
if (a_shape != b_shape) {
MS_LOG(EXCEPTION) << "For " << op_name << " two dense shape should be the same, but got a_shape is " << a_shape
<< ", and b_shape is " << b_shape;
}
this->set_a_dense_shape(a_shape);
this->set_b_dense_shape(b_shape);
this->set_thresh(thresh);
}
AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto a_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kADenseShape));
const std::string op_name = primitive->name();
constexpr size_t kAIndicesIdx = 0;
constexpr size_t kAValuesIdx = 1;
constexpr size_t kBIndicesIdx = 2;
constexpr size_t kBValuesIdx = 3;
constexpr size_t kNumOfInputs = 4;
constexpr size_t kAShapeIdx = 2;
constexpr size_t kBIndicesIdx = 3;
constexpr size_t kBValuesIdx = 4;
constexpr size_t kBShapeIdx = 5;
constexpr size_t kThreshIdx = 6;
constexpr size_t kNumOfInputs = 7;
constexpr size_t kIndicesShape = 2;
mindspore::abstract::CheckArgsSize(op_name, input_args, kNumOfInputs);
auto a_indices = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAIndicesIdx);
auto a_values = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAValuesIdx);
auto a_shape = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAShapeIdx);
auto b_indices = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBIndicesIdx);
auto b_values = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBValuesIdx);
auto b_shape = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBShapeIdx);
auto thresh = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kThreshIdx);
MS_EXCEPTION_IF_NULL(a_indices);
MS_EXCEPTION_IF_NULL(a_values);
MS_EXCEPTION_IF_NULL(a_shape);
MS_EXCEPTION_IF_NULL(b_indices);
MS_EXCEPTION_IF_NULL(b_values);
MS_EXCEPTION_IF_NULL(b_shape);
MS_EXCEPTION_IF_NULL(thresh);
// 2-D indices
auto a_indices_shape = a_indices->shape()->shape();
@ -125,6 +90,12 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
CheckSparseAddShape(a_values_shape.size(), 1, "a_values");
CheckSparseAddShape(b_values_shape.size(), 1, "b_values");
auto a_shape_shape = a_shape->shape()->shape();
auto b_shape_shape = b_shape->shape()->shape();
CheckSparseAddShape(a_shape_shape.size(), 1, "a_dense_shape");
CheckSparseAddShape(b_shape_shape.size(), 1, "b_dense_shape");
auto a_shape_type = a_shape->element()->BuildType();
auto a_type = a_values->element()->BuildType();
auto b_type = b_values->element()->BuildType();
// Input a_value and b_value should be the same data type
@ -137,7 +108,7 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
CheckSparseAddIndicesDtype(a_indices->element()->BuildType(), op_name);
CheckSparseAddIndicesDtype(b_indices->element()->BuildType(), op_name);
int64_t max_indices_shape_ = std::min(a_indices_shape[0] + b_indices_shape[0], a_shape[0] * a_shape[1]);
int64_t max_indices_shape_ = a_indices_shape[0] + b_indices_shape[0];
int64_t min_indices_shape_ = std::max(a_indices_shape[0], b_indices_shape[0]);
ShapeVector out_indices_shape{-1, 2};
ShapeVector out_value_shape{-1};
@ -151,8 +122,10 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
std::make_shared<mindspore::abstract::Shape>(out_indices_shape, min_indices_shape, max_indices_shape));
auto out_values = std::make_shared<AbstractTensor>(
a_type, std::make_shared<mindspore::abstract::Shape>(out_value_shape, min_value_shape, max_value_shape));
auto out_shape =
std::make_shared<AbstractTensor>(a_shape_type, std::make_shared<mindspore::abstract::Shape>(a_shape_shape));
AbstractBasePtrList ret = {out_indices, out_values};
AbstractBasePtrList ret = {out_indices, out_values, out_shape};
return std::make_shared<AbstractTuple>(ret);
}
MIND_API_OPERATOR_IMPL(SparseAdd, BaseOperator);

View File

@ -33,30 +33,22 @@ class MIND_API SparseAdd : public BaseOperator {
SparseAdd() : BaseOperator(kNameSparseAdd) {
InitIOName(
{
"a_indices",
"a_values",
"b_indices",
"b_values",
"x1_indices",
"x1_values",
"x1_shape",
"x2_indices",
"x2_values",
"x2_shape",
"thresh",
},
{"sum_indices", "sum_values"});
{"sum_indices", "sum_values", "sum_shape"});
}
/// \brief Init.
/// Refer to the parameters of python API @ref mindspore.ops._csr_ops.SparseAdd for the inputs.
void Init(const std::vector<int64_t> &a_shape, const std::vector<int64_t> &b_shape, const float &thresh);
/// \brief Set dense shape.
void set_a_dense_shape(const std::vector<int64_t> &shape);
void set_b_dense_shape(const std::vector<int64_t> &shape);
void set_thresh(const float &thresh);
/// \brief Get dense shape.
///
/// \return dense shape.
std::vector<int64_t> get_a_dense_shape() const;
std::vector<int64_t> get_b_dense_shape() const;
float get_thresh() const;
void Init() const {}
};
abstract::AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_MATRIX_ADD_H_
#endif // MINDSPORE_CORE_OPS_SPARSE_ADD_H_