forked from mindspore-Ecosystem/mindspore
!36418 Modify input of sparse add
Merge pull request !36418 from YijieChen/sparse_add_dev
This commit is contained in:
commit
3681ab0ecb
|
@ -15,45 +15,48 @@
|
|||
*/
|
||||
#include "plugin/device/cpu/kernel/sparse_add_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "include/common/thread_pool.h"
|
||||
#include "mindspore/core/ops/sparse_add.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Value check constant
|
||||
constexpr size_t kInputNum = 4;
|
||||
constexpr size_t kOutputNum = 2;
|
||||
constexpr size_t kInputNum = 7;
|
||||
constexpr size_t kOutputNum = 3;
|
||||
constexpr size_t kNumOfColumn = 2;
|
||||
// Input idx constant
|
||||
constexpr size_t kAIndicesIdx = 0;
|
||||
constexpr size_t kAValuesIdx = 1;
|
||||
constexpr size_t kBIndicesIdx = 2;
|
||||
constexpr size_t kBValuesIdx = 3;
|
||||
constexpr size_t kAShapeIdx = 2;
|
||||
constexpr size_t kBIndicesIdx = 3;
|
||||
constexpr size_t kBValuesIdx = 4;
|
||||
constexpr size_t kBShapeIdx = 5;
|
||||
constexpr size_t kThreshIdx = 6;
|
||||
// Output idx constant
|
||||
constexpr size_t kSumIndicesIdx = 0;
|
||||
constexpr size_t kSumValuesIdx = 1;
|
||||
constexpr size_t kSumShapeIdx = 2;
|
||||
|
||||
bool SparseAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
outputs_ = outputs;
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseAdd>(base_operator);
|
||||
thresh_ = kernel_ptr->get_thresh();
|
||||
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
size_t input_num = inputs.size();
|
||||
if (input_num != kInputNum) {
|
||||
MS_LOG(ERROR) << "For " << kernel_name_ << ", input should be a_indices, a_values, b_indices and b_values total "
|
||||
MS_LOG(ERROR) << "For " << kernel_name_
|
||||
<< ", input should be a_indices, a_values, a_shape, b_indices, b_values, b_shape and thresh total "
|
||||
<< kInputNum << " tensors, but get " << input_num;
|
||||
return false;
|
||||
}
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
auto dense_shape = kernel_ptr->get_a_dense_shape();
|
||||
row_ = LongToSize(dense_shape[0]);
|
||||
dense_size_ = row_ * LongToSize(dense_shape[1]) * GetTypeByte(TypeIdToType(types_[1]));
|
||||
|
||||
is_need_retrieve_output_shape_ = true;
|
||||
for (size_t i = 0; i < kOutputNum; i++) {
|
||||
auto dtype = inputs[i]->GetDtype();
|
||||
|
@ -71,11 +74,10 @@ int SparseAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
MS_LOG(ERROR) << "Input size list should be " << kInputNum << ", but got " << input_size_list_.size();
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
auto max_indices_out_size =
|
||||
std::min(input_size_list_[kAIndicesIdx] + input_size_list_[kBIndicesIdx], dense_size_ * 2);
|
||||
auto max_value_out_size = std::min(input_size_list_[kAValuesIdx] + input_size_list_[kBValuesIdx], dense_size_);
|
||||
output_size_list_.emplace_back(max_indices_out_size);
|
||||
output_size_list_.emplace_back(max_value_out_size);
|
||||
auto max_indices_out_size = input_size_list_[kAIndicesIdx] + input_size_list_[kBIndicesIdx];
|
||||
auto max_value_out_size = input_size_list_[kAValuesIdx] + input_size_list_[kBValuesIdx];
|
||||
output_size_list_[kSumIndicesIdx] = max_indices_out_size;
|
||||
output_size_list_[kSumValuesIdx] = max_value_out_size;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -109,11 +111,14 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
|
|||
// Inputs
|
||||
const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->addr);
|
||||
const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->addr);
|
||||
const auto a_shape = reinterpret_cast<int *>(inputs[kAShapeIdx]->addr);
|
||||
const auto b_indices = reinterpret_cast<T *>(inputs[kBIndicesIdx]->addr);
|
||||
const auto b_values = reinterpret_cast<S *>(inputs[kBValuesIdx]->addr);
|
||||
const auto thresh = reinterpret_cast<float *>(inputs[kThreshIdx]->addr);
|
||||
// Outputs
|
||||
auto sum_indices = reinterpret_cast<T *>(outputs[kSumIndicesIdx]->addr);
|
||||
auto sum_values = reinterpret_cast<S *>(outputs[kSumValuesIdx]->addr);
|
||||
auto sum_shape = reinterpret_cast<int *>(outputs[kSumShapeIdx]->addr);
|
||||
|
||||
const int64_t a_indices_num = inputs[kAIndicesIdx]->size / ((sizeof(T)) * 2);
|
||||
const int64_t b_indices_num = inputs[kBIndicesIdx]->size / ((sizeof(T)) * 2);
|
||||
|
@ -133,7 +138,7 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
|
|||
break;
|
||||
case 0:
|
||||
sum_ab = a_values[i] + b_values[j];
|
||||
if (thresh_ <= std::abs(sum_ab)) {
|
||||
if ((*thresh) <= std::abs(sum_ab)) {
|
||||
whole_indices.emplace_back(true, i);
|
||||
whole_values.push_back(sum_ab);
|
||||
}
|
||||
|
@ -177,6 +182,10 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
|
|||
sum_values[num] = whole_values[num];
|
||||
}
|
||||
|
||||
for (size_t num_out = 0; num_out < kNumOfColumn; num_out++) {
|
||||
sum_shape[num_out] = a_shape[num_out];
|
||||
}
|
||||
|
||||
// Update output shape and type
|
||||
std::vector<int64_t> out_indices_shape;
|
||||
std::vector<int64_t> out_values_shape;
|
||||
|
@ -189,31 +198,40 @@ bool SparseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &
|
|||
return true;
|
||||
}
|
||||
|
||||
#define CPU_SPARSE_ADD_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
|
||||
{ \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_value_type) \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_value_type) \
|
||||
.AddOutputAttr(ms_index_type) \
|
||||
.AddOutputAttr(ms_value_type), \
|
||||
&SparseAddCpuKernelMod::LaunchKernel<index_type, value_type> \
|
||||
#define CPU_SPARSE_ADD_KERNEL_REGISTER(ms_index_type, ms_value_type, ms_shape_type, ms_thresh_type, index_type, \
|
||||
value_type) \
|
||||
{ \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_value_type) \
|
||||
.AddInputAttr(ms_shape_type) \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_value_type) \
|
||||
.AddInputAttr(ms_shape_type) \
|
||||
.AddInputAttr(ms_thresh_type) \
|
||||
.AddOutputAttr(ms_index_type) \
|
||||
.AddOutputAttr(ms_value_type) \
|
||||
.AddOutputAttr(ms_shape_type), \
|
||||
&SparseAddCpuKernelMod::LaunchKernel<index_type, value_type> \
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> &SparseAddCpuKernelMod::GetFuncList()
|
||||
const {
|
||||
static const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
// float values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, int, float),
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, int,
|
||||
float),
|
||||
// double values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, int, double),
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat32, int,
|
||||
double),
|
||||
// int values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int, int),
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, kNumberTypeFloat32, int, int),
|
||||
// int64 values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int, int64_t),
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, kNumberTypeFloat32, int,
|
||||
int64_t),
|
||||
// int16 values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, int, int16_t),
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeFloat32, int,
|
||||
int16_t),
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
|
|
|
@ -30,10 +30,6 @@ namespace ops {
|
|||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
namespace {
|
||||
constexpr auto kADenseShape = "a_shape";
|
||||
constexpr auto kBDenseShape = "b_shape";
|
||||
constexpr auto kThresh = "thresh";
|
||||
|
||||
inline void CheckSparseAddShape(const size_t sparse_shape_size, const size_t expected_dim,
|
||||
const std::string &arg_name) {
|
||||
if (sparse_shape_size != expected_dim) {
|
||||
|
@ -45,74 +41,43 @@ inline void CheckSparseAddShape(const size_t sparse_shape_size, const size_t exp
|
|||
|
||||
inline void CheckSparseAddIndicesDtype(const mindspore::TypePtr dtype, const std::string &arg_name) {
|
||||
if (!(dtype->equal(mindspore::kInt32))) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
|
||||
<< dtype->ToString() << ".";
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " Int32 but got " << dtype->ToString() << ".";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SparseAdd::set_a_dense_shape(const std::vector<int64_t> &shape) {
|
||||
(void)this->AddAttr(kADenseShape, api::MakeValue(shape));
|
||||
}
|
||||
void SparseAdd::set_b_dense_shape(const std::vector<int64_t> &shape) {
|
||||
(void)this->AddAttr(kBDenseShape, api::MakeValue(shape));
|
||||
}
|
||||
void SparseAdd::set_thresh(const float &thresh) { (void)this->AddAttr(kThresh, api::MakeValue(thresh)); }
|
||||
std::vector<int64_t> SparseAdd::get_a_dense_shape() const {
|
||||
auto value_ptr = GetAttr(kADenseShape);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
std::vector<int64_t> SparseAdd::get_b_dense_shape() const {
|
||||
auto value_ptr = GetAttr(kBDenseShape);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
float SparseAdd::get_thresh() const {
|
||||
auto value_ptr = GetAttr(kThresh);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void SparseAdd::Init(const std::vector<int64_t> &a_shape, const std::vector<int64_t> &b_shape, const float &thresh) {
|
||||
auto op_name = this->name();
|
||||
if (a_shape.size() != b_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "For " << op_name
|
||||
<< " the rank of two dense shape should be the same, but got the rank of a_shape is "
|
||||
<< a_shape.size() << ", and b_shape is " << b_shape.size();
|
||||
}
|
||||
if (a_shape != b_shape) {
|
||||
MS_LOG(EXCEPTION) << "For " << op_name << " two dense shape should be the same, but got a_shape is " << a_shape
|
||||
<< ", and b_shape is " << b_shape;
|
||||
}
|
||||
this->set_a_dense_shape(a_shape);
|
||||
this->set_b_dense_shape(b_shape);
|
||||
this->set_thresh(thresh);
|
||||
}
|
||||
|
||||
AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto a_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kADenseShape));
|
||||
const std::string op_name = primitive->name();
|
||||
|
||||
constexpr size_t kAIndicesIdx = 0;
|
||||
constexpr size_t kAValuesIdx = 1;
|
||||
constexpr size_t kBIndicesIdx = 2;
|
||||
constexpr size_t kBValuesIdx = 3;
|
||||
constexpr size_t kNumOfInputs = 4;
|
||||
constexpr size_t kAShapeIdx = 2;
|
||||
constexpr size_t kBIndicesIdx = 3;
|
||||
constexpr size_t kBValuesIdx = 4;
|
||||
constexpr size_t kBShapeIdx = 5;
|
||||
constexpr size_t kThreshIdx = 6;
|
||||
|
||||
constexpr size_t kNumOfInputs = 7;
|
||||
constexpr size_t kIndicesShape = 2;
|
||||
|
||||
mindspore::abstract::CheckArgsSize(op_name, input_args, kNumOfInputs);
|
||||
auto a_indices = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAIndicesIdx);
|
||||
auto a_values = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAValuesIdx);
|
||||
auto a_shape = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kAShapeIdx);
|
||||
auto b_indices = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBIndicesIdx);
|
||||
auto b_values = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBValuesIdx);
|
||||
auto b_shape = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kBShapeIdx);
|
||||
auto thresh = mindspore::abstract::CheckArg<AbstractTensor>(op_name, input_args, kThreshIdx);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(a_indices);
|
||||
MS_EXCEPTION_IF_NULL(a_values);
|
||||
MS_EXCEPTION_IF_NULL(a_shape);
|
||||
MS_EXCEPTION_IF_NULL(b_indices);
|
||||
MS_EXCEPTION_IF_NULL(b_values);
|
||||
MS_EXCEPTION_IF_NULL(b_shape);
|
||||
MS_EXCEPTION_IF_NULL(thresh);
|
||||
|
||||
// 2-D indices
|
||||
auto a_indices_shape = a_indices->shape()->shape();
|
||||
|
@ -125,6 +90,12 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
CheckSparseAddShape(a_values_shape.size(), 1, "a_values");
|
||||
CheckSparseAddShape(b_values_shape.size(), 1, "b_values");
|
||||
|
||||
auto a_shape_shape = a_shape->shape()->shape();
|
||||
auto b_shape_shape = b_shape->shape()->shape();
|
||||
CheckSparseAddShape(a_shape_shape.size(), 1, "a_dense_shape");
|
||||
CheckSparseAddShape(b_shape_shape.size(), 1, "b_dense_shape");
|
||||
auto a_shape_type = a_shape->element()->BuildType();
|
||||
|
||||
auto a_type = a_values->element()->BuildType();
|
||||
auto b_type = b_values->element()->BuildType();
|
||||
// Input a_value and b_value should be the same data type
|
||||
|
@ -137,7 +108,7 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
CheckSparseAddIndicesDtype(a_indices->element()->BuildType(), op_name);
|
||||
CheckSparseAddIndicesDtype(b_indices->element()->BuildType(), op_name);
|
||||
|
||||
int64_t max_indices_shape_ = std::min(a_indices_shape[0] + b_indices_shape[0], a_shape[0] * a_shape[1]);
|
||||
int64_t max_indices_shape_ = a_indices_shape[0] + b_indices_shape[0];
|
||||
int64_t min_indices_shape_ = std::max(a_indices_shape[0], b_indices_shape[0]);
|
||||
ShapeVector out_indices_shape{-1, 2};
|
||||
ShapeVector out_value_shape{-1};
|
||||
|
@ -151,8 +122,10 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
std::make_shared<mindspore::abstract::Shape>(out_indices_shape, min_indices_shape, max_indices_shape));
|
||||
auto out_values = std::make_shared<AbstractTensor>(
|
||||
a_type, std::make_shared<mindspore::abstract::Shape>(out_value_shape, min_value_shape, max_value_shape));
|
||||
auto out_shape =
|
||||
std::make_shared<AbstractTensor>(a_shape_type, std::make_shared<mindspore::abstract::Shape>(a_shape_shape));
|
||||
|
||||
AbstractBasePtrList ret = {out_indices, out_values};
|
||||
AbstractBasePtrList ret = {out_indices, out_values, out_shape};
|
||||
return std::make_shared<AbstractTuple>(ret);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(SparseAdd, BaseOperator);
|
||||
|
|
|
@ -33,30 +33,22 @@ class MIND_API SparseAdd : public BaseOperator {
|
|||
SparseAdd() : BaseOperator(kNameSparseAdd) {
|
||||
InitIOName(
|
||||
{
|
||||
"a_indices",
|
||||
"a_values",
|
||||
"b_indices",
|
||||
"b_values",
|
||||
"x1_indices",
|
||||
"x1_values",
|
||||
"x1_shape",
|
||||
"x2_indices",
|
||||
"x2_values",
|
||||
"x2_shape",
|
||||
"thresh",
|
||||
},
|
||||
{"sum_indices", "sum_values"});
|
||||
{"sum_indices", "sum_values", "sum_shape"});
|
||||
}
|
||||
/// \brief Init.
|
||||
/// Refer to the parameters of python API @ref mindspore.ops._csr_ops.SparseAdd for the inputs.
|
||||
void Init(const std::vector<int64_t> &a_shape, const std::vector<int64_t> &b_shape, const float &thresh);
|
||||
/// \brief Set dense shape.
|
||||
void set_a_dense_shape(const std::vector<int64_t> &shape);
|
||||
void set_b_dense_shape(const std::vector<int64_t> &shape);
|
||||
void set_thresh(const float &thresh);
|
||||
/// \brief Get dense shape.
|
||||
///
|
||||
/// \return dense shape.
|
||||
std::vector<int64_t> get_a_dense_shape() const;
|
||||
std::vector<int64_t> get_b_dense_shape() const;
|
||||
float get_thresh() const;
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_MATRIX_ADD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_ADD_H_
|
||||
|
|
Loading…
Reference in New Issue