[assistant][ops][I48O5I] add CSRSparseMatrixToDense operator
This commit is contained in:
parent
058f6a2577
commit
e7ffa50b3e
|
@ -0,0 +1,202 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/csr_sparse_matrix_to_dense_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kZero = 0;
|
||||
constexpr size_t kOne = 1;
|
||||
constexpr size_t kDefaultRank = 2;
|
||||
constexpr size_t kInputIndex0 = 0;
|
||||
constexpr size_t kInputIndex1 = 1;
|
||||
constexpr size_t kInputIndex2 = 2;
|
||||
constexpr size_t kInputIndex3 = 3;
|
||||
constexpr size_t kInputIndex4 = 4;
|
||||
constexpr size_t kOutputIndex = 0;
|
||||
constexpr size_t kCSRSparseMatrixToDenseInputsNum = 5;
|
||||
constexpr size_t kCSRSparseMatrixToDenseOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
void CSRSparseMatrixToDenseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
node_wpt_ = kernel_node;
|
||||
indices_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex0);
|
||||
values_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex4);
|
||||
rank_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex0)[kZero];
|
||||
batch_size_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex1)[kZero] - kOne;
|
||||
}
|
||||
|
||||
bool CSRSparseMatrixToDenseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCSRSparseMatrixToDenseInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCSRSparseMatrixToDenseOutputsNum, kernel_name_);
|
||||
switch (indices_type) {
|
||||
case kNumberTypeInt32:
|
||||
switch (values_type) {
|
||||
case kNumberTypeFloat32:
|
||||
LaunchKernel<int32_t, float>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeFloat64:
|
||||
LaunchKernel<int32_t, double>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeComplex64:
|
||||
LaunchKernel<int32_t, std::complex<float>>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeComplex128:
|
||||
LaunchKernel<int32_t, std::complex<double>>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be "
|
||||
<< "float32, float64, complex64 or complex128, but got "
|
||||
<< TypeIdToType(values_type)->ToString();
|
||||
}
|
||||
break;
|
||||
case kNumberTypeInt64:
|
||||
switch (values_type) {
|
||||
case kNumberTypeFloat32:
|
||||
LaunchKernel<int64_t, float>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeFloat64:
|
||||
LaunchKernel<int64_t, double>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeComplex64:
|
||||
LaunchKernel<int64_t, std::complex<float>>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeComplex128:
|
||||
LaunchKernel<int64_t, std::complex<double>>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be "
|
||||
<< "float32, float64, complex64 or complex128, but got "
|
||||
<< TypeIdToType(values_type)->ToString();
|
||||
}
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of indices should be int32 or int64, "
|
||||
<< "but got " << TypeIdToType(indices_type)->ToString();
|
||||
}
|
||||
auto node_ = node_wpt_.lock();
|
||||
if (!node_) {
|
||||
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
|
||||
}
|
||||
std::vector<TypeId> y_dtype = {values_type};
|
||||
std::vector<int64_t> y_dims;
|
||||
if (rank_ == kDefaultRank) {
|
||||
y_dims = {SizeToLong(num_rows_), SizeToLong(num_cols_)};
|
||||
} else {
|
||||
y_dims = {SizeToLong(batch_size_), SizeToLong(num_rows_), SizeToLong(num_cols_)};
|
||||
}
|
||||
(void)common::AnfAlgo::SetOutputInferTypeAndShape(y_dtype, {y_dims}, node_.get());
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename indiceT, typename valueT>
|
||||
void CSRSparseMatrixToDenseCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const size_t shift = (rank_ == kDefaultRank) ? kZero : kOne;
|
||||
num_rows_ = *(static_cast<indiceT *>(inputs[kInputIndex0]->addr) + shift);
|
||||
num_cols_ = *(static_cast<indiceT *>(inputs[kInputIndex0]->addr) + shift + kOne);
|
||||
indiceT *batch_ptrs = static_cast<indiceT *>(inputs[kInputIndex1]->addr);
|
||||
indiceT *row_ptrs = static_cast<indiceT *>(inputs[kInputIndex2]->addr);
|
||||
indiceT *col_ind = static_cast<indiceT *>(inputs[kInputIndex3]->addr);
|
||||
valueT *values = static_cast<valueT *>(inputs[kInputIndex4]->addr);
|
||||
valueT *y_data = static_cast<valueT *>(outputs[kOutputIndex]->addr);
|
||||
for (size_t batch_idx = kZero; batch_idx < batch_size_; batch_idx++) {
|
||||
const size_t dense_offset = batch_idx * num_rows_ * num_cols_;
|
||||
for (size_t i = kZero; i < num_rows_ * num_cols_; ++i) {
|
||||
y_data[dense_offset + i] = valueT(kZero);
|
||||
}
|
||||
const size_t csr_batch_offset = batch_ptrs[batch_idx];
|
||||
for (size_t row_idx = kZero; row_idx < num_rows_; ++row_idx) {
|
||||
const size_t row_offset = batch_idx * (num_rows_ + kOne) + row_idx;
|
||||
const size_t col_begin = row_ptrs[row_offset];
|
||||
const size_t col_end = row_ptrs[row_offset + kOne];
|
||||
for (size_t i = col_begin; i < col_end; ++i) {
|
||||
const size_t col_idx = col_ind[csr_batch_offset + i];
|
||||
y_data[dense_offset + (row_idx * num_cols_) + col_idx] = values[csr_batch_offset + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CSRSparseMatrixToDenseCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CSRSparseMatrixToDense, CSRSparseMatrixToDenseCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CSR_SPARSE_MATRIX_TO_DENSE_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CSR_SPARSE_MATRIX_TO_DENSE_KERNEL_H_
|
||||
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class CSRSparseMatrixToDenseCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
CSRSparseMatrixToDenseCpuKernelMod() = default;
|
||||
~CSRSparseMatrixToDenseCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename valueT, typename indiceT>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
size_t rank_{0};
|
||||
size_t batch_size_{0};
|
||||
size_t num_rows_{0};
|
||||
size_t num_cols_{0};
|
||||
TypeId values_type{kTypeUnknown};
|
||||
TypeId indices_type{kTypeUnknown};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CSR_SPARSE_MATRIX_TO_DENSE_KERNEL_H_
|
|
@ -885,6 +885,7 @@ GVAR_DEF(PrimitivePtr, kPrimDenseToCSRSparseMatrix, std::make_shared<Primitive>(
|
|||
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor, std::make_shared<Primitive>(kCSRSparseMatrixToSparseTensor));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseConcat, std::make_shared<Primitive>(kSparseConcat));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixNNZ, std::make_shared<Primitive>(kSparseMatrixNNZ));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
|
||||
|
||||
// Sparse Grad ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_sparse_matrix_to_dense.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto d_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto b_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto r_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto c_ind_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
const int64_t kZero = 0, kOne = 1, kDefalutRank = 2, kBatchRank = 3;
|
||||
const int64_t rank = d_shape_shape[kZero];
|
||||
if (d_shape_shape.size() != kOne || c_ind_shape.size() != kOne || values_shape.size() != kOne ||
|
||||
r_ptrs_shape.size() != kOne || b_ptrs_shape.size() != kOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', each input should be 1-D, but got "
|
||||
<< "'x_dense_shape' rank " << d_shape_shape.size() << ", 'x_batch_pointers' rank "
|
||||
<< b_ptrs_shape.size() << ", 'x_row_pointers' rank " << r_ptrs_shape.size()
|
||||
<< ", 'x_col_indices' rank " << c_ind_shape.size() << ", 'x_values' rank "
|
||||
<< values_shape.size() << ".";
|
||||
}
|
||||
if (rank != kDefalutRank && rank != kBatchRank) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', dense form of the input "
|
||||
<< "should have rank 2 or 3, but got " << d_shape_shape[kZero] << ".";
|
||||
}
|
||||
if (values_shape[kZero] != c_ind_shape[kZero]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'col_indices' and 'values' "
|
||||
<< "should have the same length.";
|
||||
}
|
||||
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<None>()) {
|
||||
ShapeVector y_shape;
|
||||
auto d_shape_value = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(d_shape_value);
|
||||
auto d_shape_value_ptr = d_shape_value->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(d_shape_value_ptr);
|
||||
auto d_shape_value_ptr_tensor =
|
||||
CheckAndConvertUtils::CheckTensorIntValue("x_dense_shape", d_shape_value_ptr, primitive->name());
|
||||
for (int64_t i = kZero; i < rank; i++) {
|
||||
if (d_shape_value_ptr_tensor[i] <= kZero) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
||||
<< "', each element of 'x_dense_shape' must be greater than 0.";
|
||||
}
|
||||
}
|
||||
int64_t batch_size = kOne;
|
||||
int64_t row_num = d_shape_value_ptr_tensor[kZero];
|
||||
if (rank == kBatchRank) {
|
||||
batch_size = d_shape_value_ptr_tensor[kZero], row_num = d_shape_value_ptr_tensor[kOne];
|
||||
}
|
||||
if (b_ptrs_shape[kZero] != (batch_size + kOne) || r_ptrs_shape[kZero] != batch_size * (row_num + kOne)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', batch size of the input is " << batch_size
|
||||
<< ", row numbers of the input is " << row_num << ", so shape of 'x_batch_pointers' "
|
||||
<< "should be (" << batch_size + kOne << "), but got (" << b_ptrs_shape[kZero] << ")"
|
||||
<< ", shape of 'x_row_pointers' should be (" << batch_size * (row_num + kOne) << "), "
|
||||
<< "but got (" << r_ptrs_shape[kZero] << ").";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(y_shape);
|
||||
} else {
|
||||
ShapeVector dense_shape = {-2};
|
||||
ShapeVector infer_shape_min;
|
||||
ShapeVector infer_shape_max;
|
||||
infer_shape_min = infer_shape_max = {1};
|
||||
return std::make_shared<abstract::Shape>(dense_shape, infer_shape_min, infer_shape_max);
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr CSRSparseMatrixToDenseInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = prim->name();
|
||||
const std::set<TypePtr> valid_values_types = {kFloat64, kFloat32, kComplex128, kComplex64};
|
||||
const std::set<TypePtr> valid_indices_types = {kInt32, kInt64};
|
||||
std::map<std::string, TypePtr> indices_args;
|
||||
(void)indices_args.emplace("x_dense_shape", input_args[kInputIndex0]->BuildType());
|
||||
(void)indices_args.emplace("x_batch_pointers", input_args[kInputIndex1]->BuildType());
|
||||
(void)indices_args.emplace("x_row_pointers", input_args[kInputIndex2]->BuildType());
|
||||
(void)indices_args.emplace("x_col_indices", input_args[kInputIndex3]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(indices_args, valid_indices_types, op_name);
|
||||
auto values_type = input_args[kInputIndex4]->BuildType();
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("x_values", values_type, valid_values_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(CSRSparseMatrixToDense, BaseOperator);
|
||||
AbstractBasePtr CSRSparseMatrixToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 5;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto types = CSRSparseMatrixToDenseInferType(primitive, input_args);
|
||||
auto shapes = CSRSparseMatrixToDenseInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_HOST_DEPENDS(kNameCSRSparseMatrixToDense, {0});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRSparseMatrixToDense, prim::kPrimCSRSparseMatrixToDense, CSRSparseMatrixToDenseInfer,
|
||||
nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_SPARSE_MATRIX_TO_DENSE
|
||||
#define MINDSPORE_CORE_OPS_CSR_SPARSE_MATRIX_TO_DENSE
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRSparseMatrixToDense = "CSRSparseMatrixToDense";
|
||||
/// \brief Converts a CSR sparse matrix to its dense form.
|
||||
/// Refer to Python API @ref mindspore.ops.CSRSparseMatrixToDense for more details.
|
||||
class MIND_API CSRSparseMatrixToDense : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRSparseMatrixToDense);
|
||||
/// \brief Constructor.
|
||||
CSRSparseMatrixToDense() : BaseOperator(kNameCSRSparseMatrixToDense) {
|
||||
InitIOName({"x_dense_shape", "x_batch_pointers", "x_row_pointers", "x_col_indices", "x_values"}, {"y"});
|
||||
}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRSparseMatrixToDenseInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_SPARSE_MATRIX_TO_DENSE
|
|
@ -269,3 +269,4 @@ from .multinomial import _multinomial_aicpu
|
|||
from .pow import _pow_aicpu
|
||||
from .depth_to_space import _depth_to_space_aicpu
|
||||
from .space_to_depth import _space_to_depth_aicpu
|
||||
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CSRSparseMatrixToDense op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
csr_sparse_matrix_to_dense_op_info = AiCPURegOp("CSRSparseMatrixToDense") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x_dense_shape", "required") \
|
||||
.input(1, "x_batch_pointers", "required") \
|
||||
.input(2, "x_row_pointers", "required") \
|
||||
.input(3, "x_col_indices", "required") \
|
||||
.input(4, "x_values", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.C128_Default, DataType.C128_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(csr_sparse_matrix_to_dense_op_info)
|
||||
def _csr_sparse_matrix_to_dense_aicpu():
|
||||
"""CSRSparseMatrixToDense AiCPU register"""
|
||||
return
|
|
@ -691,3 +691,61 @@ class SparseAdd(Primitive):
|
|||
self.init_prim_io_names(
|
||||
inputs=["x1_indices", "x1_values", "x1_shape", "x2_indices", "x2_values", "x2_shape", "thresh"],
|
||||
outputs=["sum_indices", "sum_values", "sum_shape"])
|
||||
|
||||
|
||||
class CSRSparseMatrixToDense(Primitive):
|
||||
"""
|
||||
Converts a CSR sparse matrix(maybe batched) to its dense form.
|
||||
|
||||
Note:
|
||||
It is assumed that all the inputs can form a legal CSR sparse matrix, otherwise this operator won't work.
|
||||
|
||||
Inputs:
|
||||
- **x_dense_shape** (Tensor) - A 1-D Tensor. It represents the dense form shape of
|
||||
the input CSR sparse matrix, the shape of which should be :math:`(2,)` or :math:`(3,)`.
|
||||
- **x_batch_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
|
||||
batch size `n`, it should have shape :math:`(n+1,)`, while the `i`-th element of which stores
|
||||
acummulated counts of nonzero values of the first `i - 1` batches.
|
||||
- **x_row_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
|
||||
batch size `n` and row number `m`, it can be divided into `n` parts, each part of length
|
||||
`m + 1`. The `i`-th element of each :math:`(m+1,)` vector stores acummulated counts of
|
||||
nonzero values of the first `i - 1` rows in the corresponding batch.
|
||||
- **x_col_indices** (Tensor) - A 1-D Tensor. It represents column indices of the nonzero values
|
||||
in the input CSR sparse matrix.
|
||||
- **x_values** (Tensor) - A 1-D Tensor. It represents all the nonzero values in the
|
||||
input CSR sparse matrix.
|
||||
|
||||
Outputs:
|
||||
Tensor, which is the dense form of the input CSR sparse matrix.
|
||||
Its dtype is the same as `x_values`.
|
||||
|
||||
Raises:
|
||||
TypeError: If the dtype of `x_dense_shape`, `x_batch_pointers`, `x_row_pointers` or `x_col_indices`
|
||||
is not int32 or int64, or the dtypes of above inputs are not the same.
|
||||
TypeError: If the dtype of `x_values` is not float32, float64, complex64 or complex128.
|
||||
TypeError: If any of the inputs is not a tensor.
|
||||
ValueError: If any of the inputs is not 1-D.
|
||||
ValueError: If shape[0] of `x_dense_shape` is not 2 or 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> dense_shape = Tensor([2, 2], dtype=ms.int32)
|
||||
>>> batch_pointers = Tensor([0, 1], dtype=ms.int32)
|
||||
>>> row_pointers = Tensor([0, 1, 1], dtype=ms.int32)
|
||||
>>> col_indices = Tensor([1], dtype=ms.int32)
|
||||
>>> values = Tensor([1.], dtype=ms.float32)
|
||||
>>> csr_to_dense = ops.CSRSparseMatrixToDense()
|
||||
>>> out = csr_to_dense(dense_shape, batch_pointers, row_pointers, col_indices, values)
|
||||
>>> print(out)
|
||||
[[0. 1.]
|
||||
[0. 0.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize CSRSparseMatrixToDense"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'],
|
||||
outputs=['y'])
|
||||
|
|
|
@ -102,6 +102,7 @@ from mindspore.ops.operations.nn_ops import MaxPoolV1
|
|||
from mindspore.ops.operations.array_ops import NonZero
|
||||
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
|
||||
from mindspore.ops.operations.nn_ops import ReLUV3
|
||||
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToDense
|
||||
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorDenseMatmul
|
||||
from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
|
||||
|
@ -1699,6 +1700,14 @@ test_case_math_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[1.0, 0.0], [0.0, 1.0]]).astype(np.float32)),
|
||||
Tensor(np.array([[5.0, 2.0], [3.0, 5.0]]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([[3.0, 5.0], [5.0, 7.0]]).astype(np.float32))]}),
|
||||
('CSRSparseMatrixToDense', {
|
||||
'block': CSRSparseMatrixToDense(),
|
||||
'desc_inputs': [Tensor(np.array([2, 2, 2]).astype(np.int64)),
|
||||
Tensor(np.array([0, 2, 4]).astype(np.int64)),
|
||||
Tensor(np.array([0, 1, 2, 0, 1, 2]).astype(np.int64)),
|
||||
Tensor(np.array([0, 1, 0, 1]).astype(np.int64)),
|
||||
Tensor(np.array([5, 2, 3, 5]).astype(np.float64))],
|
||||
'skip': ['backward']}),
|
||||
('DenseToCSRSparseMatrix', {
|
||||
'block': DenseToCSRSparseMatrix(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 0], [0, 1]]).astype(np.float32)),
|
||||
|
|
Loading…
Reference in New Issue