!32175 [feat][assistant][I48O92]Add SparseMatrixTranspose

Merge pull request !32175 from 李定维/SparseMatrixTranspose
This commit is contained in:
i-robot 2022-07-20 01:44:31 +00:00 committed by Gitee
commit bd82adc603
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 929 additions and 2 deletions

View File

@ -0,0 +1,410 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
// Small constant used for rank / stride arithmetic (rows axis is rank - 2).
constexpr int64_t kTwo = 2;
// The kernel consumes and produces a 5-tuple CSR representation:
// (dense_shape, batch_pointers, row_pointers, col_indices, values).
constexpr int64_t kInputsNum = 5;
constexpr int64_t kOutputsNum = 5;
// Positional indices of the CSR components in the input tuple.
constexpr int64_t kInputIndex0 = 0;
constexpr int64_t kInputIndex1 = 1;
constexpr int64_t kInputIndex2 = 2;
constexpr int64_t kInputIndex3 = 3;
constexpr int64_t kInputIndex4 = 4;
// Positional indices of the CSR components in the output tuple.
constexpr int64_t kOutputIndex0 = 0;
constexpr int64_t kOutputIndex1 = 1;
constexpr int64_t kOutputIndex2 = 2;
constexpr int64_t kOutputIndex3 = 3;
constexpr int64_t kOutputIndex4 = 4;
// Indices into the dense_shape tensor: [batch,] rows, cols.
constexpr int64_t kDenseShape0 = 0;
constexpr int64_t kDenseShape1 = 1;
constexpr int64_t kDenseShape2 = 2;
// dense_shape holds 2 entries (rows, cols) without batch, 3 with a leading batch dim.
constexpr int64_t kRankWithOutBatch = 2;
constexpr int64_t kRankWithBatch = 3;
// Builds one supported dtype combination for the kernel registry: the first
// five TypeIds describe the inputs (dense_shape, batch_pointers, row_pointers,
// col_indices, values), the last five the corresponding outputs.
KernelAttr AddKernel(const TypeId &in1, const TypeId &in2, const TypeId &in3, const TypeId &in4, const TypeId &in5,
                     const TypeId &out1, const TypeId &out2, const TypeId &out3, const TypeId &out4,
                     const TypeId &out5) {
  return KernelAttr()
    .AddInputAttr(in1)
    .AddInputAttr(in2)
    .AddInputAttr(in3)
    .AddInputAttr(in4)
    .AddInputAttr(in5)
    .AddOutputAttr(out1)
    .AddOutputAttr(out2)
    .AddOutputAttr(out3)
    .AddOutputAttr(out4)
    .AddOutputAttr(out5);
}
// Expands `ADD_KERNEL(Int32, ..., Float32)` into an AddKernel call on the
// corresponding kNumberType* enumerators.
#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)                                                       \
  AddKernel(kNumberType##t1, kNumberType##t2, kNumberType##t3, kNumberType##t4, kNumberType##t5, kNumberType##t6, \
            kNumberType##t7, kNumberType##t8, kNumberType##t9, kNumberType##t10)
// One switch case dispatching to the typed LaunchKernel implementation.
#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
  case (DTYPE): {                                                                        \
    LaunchKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                   \
    break;                                                                               \
  }
// Same as above but dispatches to the complex-valued implementation (which can
// additionally conjugate the values).
#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
  case (DTYPE): {                                                                                \
    LaunchcomplexKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                    \
    break;                                                                                       \
  }
// Guards against the weak_ptr to the kernel node having expired.
#define NODE_CHECK_AND_OUTPUT_TYPE(node_)          \
  do {                                             \
    if (!node_) {                                  \
      MS_LOG(EXCEPTION) << "node_wpt_ is expired."; \
    }                                              \
  } while (0);
// Re-infers the five output shapes, substituting the freshly computed
// y_row_pointers shape (which depends on the runtime dense_shape values).
#define SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)                           \
  common::AnfAlgo::SetOutputInferTypeAndShape(                                                   \
    dtypes,                                                                                      \
    {common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex0),                                 \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex1), y_row_pointers_shape,           \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex3),                                 \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex4)},                                \
    node_.get());
// For batched input, batch_pointers must hold dense_shape[0] + 1 entries.
#define BATCH_CHECK(batch, batch_pointers, kernel_name_)                                                   \
  do {                                                                                                     \
    if (batch + 1 != batch_pointers) {                                                                     \
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of batch pionters shape should equals" \
                        << "dense shape[0] + 1 to match the CSR form input when input has batch.";         \
    }                                                                                                      \
  } while (0);
} // namespace
// Caches the static shape information of the five CSR inputs, the 'conjugate'
// attribute, and the index/value dtypes used by Launch for dispatch.
// Raises an exception if the input rank is not 2/3 or col_indices and values
// shapes disagree (a malformed CSR tensor).
void SparseMatrixTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  int64_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
  int64_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kernel_name_);
  // dense_shape is a 1-D tensor whose length equals the rank of the sparse matrix.
  std::vector<int64_t> input_dense_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex0);
  rank_x_ = input_dense_shape[0];
  if (rank_x_ != kRankWithOutBatch && rank_x_ != kRankWithBatch) {
    MS_LOG(EXCEPTION) << "For SparseMatrixTranspose, the rank must be 2 or 3, but got " << rank_x_ << "!";
  }
  conjugate = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "conjugate");
  std::vector<int64_t> x_batch_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex1);
  std::vector<int64_t> x_row_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex2);
  std::vector<int64_t> x_col_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex3);
  std::vector<int64_t> x_value_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex4);
  x_value_size_ = x_value_shape[0];
  x_batch_pointers_size_ = x_batch_pointers_shape[0];
  x_col_indice_size_ = x_col_indices_shape[0];
  x_row_pointer_size_ = x_row_pointers_shape[0];
  // In CSR form there is exactly one column index per stored value.
  if (x_col_indice_size_ != x_value_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of col indices should equal "
                      << "the shape of values to match the CSR form input.";
  }
  node_wpt_ = kernel_node;
  indiceT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex0);
  valueT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex4);
}
// Dispatches to the typed implementation: outer switch on the index dtype
// (int32/int64, cached in InitKernel), inner switch on the value dtype.
// Complex value types go through LaunchcomplexKernel so 'conjugate' can apply.
bool SparseMatrixTransposeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  switch (indiceT_) {
    case kNumberTypeInt32:
      switch (valueT_) {
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int32_t, int8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int32_t, uint8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int32_t, int16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int32_t, uint16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int32_t, uint32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int32_t, int64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int32_t, uint64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int32_t, float16, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int32_t, float, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int32_t, double, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int32_t, complex64, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int32_t, complex128, inputs, outputs)
        default:
          MS_LOG(EXCEPTION) << "data type of values is not included.";
          break;
      }
      break;
    case kNumberTypeInt64:
      switch (valueT_) {
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int64_t, int8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int64_t, uint8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int64_t, int16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int64_t, uint16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int64_t, int32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int64_t, uint32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int64_t, uint64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int64_t, float16, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int64_t, float, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int64_t, double, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int64_t, complex64, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int64_t, complex128, inputs, outputs)
        default:
          MS_LOG(EXCEPTION) << "data type of values is not included.";
          break;
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "The data type of dense_shape, batch_pointers, "
                        << "row_pointers, col_indices is not int32 or int64.";
      break;
  }
  return true;
}
// Transposes a (possibly batched) CSR matrix.
//
// Algorithm (per batch slice): a counting sort over column indices —
//   1. histogram the column indices into y_row_pointers,
//   2. prefix-sum the histogram to get the transposed row pointers,
//   3. scatter each (row, value) pair into its column bucket, preserving the
//      original row order within a column (so output col_indices stay sorted).
// Runs in O(nnz + rows + cols) per batch.
//
// inputs:  {dense_shape, batch_pointers, row_pointers, col_indices, values}
// outputs: the same five CSR components of the transposed matrix.
template <typename indiceT, typename valueT>
bool SparseMatrixTransposeCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                     const std::vector<kernel::AddressPtr> &outputs) {
  // Raw CSR components of the input matrix.
  indiceT *x_dense_shape = static_cast<indiceT *>(inputs[kInputIndex0]->addr);
  indiceT *x_batch_pointers = static_cast<indiceT *>(inputs[kInputIndex1]->addr);
  indiceT *x_row_pointers = static_cast<indiceT *>(inputs[kInputIndex2]->addr);
  indiceT *x_col_indices = static_cast<indiceT *>(inputs[kInputIndex3]->addr);
  valueT *x_values = static_cast<valueT *>(inputs[kInputIndex4]->addr);
  // Raw CSR components of the output (transposed) matrix.
  indiceT *y_dense_shape_addr = static_cast<indiceT *>(outputs[kOutputIndex0]->addr);
  indiceT *y_batch_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex1]->addr);
  indiceT *y_row_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex2]->addr);
  indiceT *y_col_indices_addr = static_cast<indiceT *>(outputs[kOutputIndex3]->addr);
  valueT *y_values_addr = static_cast<valueT *>(outputs[kOutputIndex4]->addr);
  std::vector<int64_t> y_row_pointers_shape;
  int64_t batch_pointers = x_batch_pointers_size_;
  if (rank_x_ == kRankWithBatch) {
    // Batched input: keep the batch extent, swap rows and cols.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
    int64_t batch = x_dense_shape[kDenseShape0];
    BATCH_CHECK(batch, batch_pointers, kernel_name_)
  } else {
    // Unbatched input: swap rows and cols; batch_pointers must be {0, nnz}.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
    if (batch_pointers != kTwo) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input batch pointers should equal "
                        << "2 to match the CSR form input when input has no batch.";
    }
  }
  // The batch partition of the nonzeros is unchanged by transposition.
  for (int64_t i = 0; i < batch_pointers; ++i) {
    y_batch_pointers_addr[i] = x_batch_pointers[i];
  }
  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
  int64_t num_cols = x_dense_shape[rank_x_ - 1];
  int64_t num_batch = batch_pointers - 1;
  std::vector<int64_t> y_part_row_pointers(num_cols + 1);
  std::vector<int64_t> part_row_pointers(num_rows + 1);
  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input row pointers should equal"
                      << " batch * (rows + 1) to match the CSR form input.";
  }
  for (int64_t j = 0; j < num_batch; ++j) {
    // n = number of nonzeros in this batch slice.
    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
    std::vector<valueT> part_values(n);
    std::vector<indiceT> part_col_indices(n);
    std::vector<indiceT> y_part_col_indices(n);
    std::vector<valueT> y_part_values(n);
    for (int64_t i = 0; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] = 0;
    }
    // Copy this slice's row pointers, values and column indices out of the
    // flat input buffers.
    for (int64_t k = 0; k < num_rows + 1; ++k) {
      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
    }
    for (int64_t k = 0; k < n; ++k) {
      part_values[k] = x_values[x_batch_pointers[j] + k];
      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
    }
    // Step 1: histogram of column indices (shifted by one for the prefix sum).
    for (int64_t i = 0; i < n; ++i) {
      y_part_row_pointers[part_col_indices[i] + 1] += 1;
    }
    // Step 2: prefix sum turns counts into the transposed row pointers.
    for (int64_t i = 1; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
    }
    for (int64_t k = 0; k < num_cols + 1; ++k) {
      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
    }
    // Step 3: stable scatter — walking rows in order keeps output col_indices
    // sorted within each transposed row.
    std::vector<int64_t> current_col_count(num_cols);
    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = part_row_pointers[row_idx];
      const int64_t row_end = part_row_pointers[row_idx + 1];
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int64_t col_idx = part_col_indices[i];
        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
        y_part_col_indices[offset] = row_idx;
        y_part_values[offset] = part_values[i];
        current_col_count[col_idx] += 1;
      }
    }
    for (int64_t k = 0; k < n; ++k) {
      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
    }
  }
  // Publish the runtime-dependent y_row_pointers shape back onto the node.
  auto node_ = node_wpt_.lock();
  NODE_CHECK_AND_OUTPUT_TYPE(node_)
  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
  std::vector<TypeId> dtypes(output_nm);
  for (int64_t i = 0; i < output_nm; i++) {
    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
  }
  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
  return true;
}
// Complex-valued variant of the CSR transpose. Identical counting-sort
// algorithm, with one extra pass: when the 'conjugate' attribute is set the
// transposed values are conjugated in place (conjugate-transpose).
// Also uses static_cast (not reinterpret_cast) for the void* buffer
// conversions, matching LaunchKernel.
template <typename indiceT, typename valueT>
bool SparseMatrixTransposeCpuKernelMod::LaunchcomplexKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                            const std::vector<kernel::AddressPtr> &outputs) {
  // Raw CSR components of the input matrix.
  indiceT *x_dense_shape = static_cast<indiceT *>(inputs[kInputIndex0]->addr);
  indiceT *x_batch_pointers = static_cast<indiceT *>(inputs[kInputIndex1]->addr);
  indiceT *x_row_pointers = static_cast<indiceT *>(inputs[kInputIndex2]->addr);
  indiceT *x_col_indices = static_cast<indiceT *>(inputs[kInputIndex3]->addr);
  valueT *x_values = static_cast<valueT *>(inputs[kInputIndex4]->addr);
  // Raw CSR components of the output (transposed) matrix.
  indiceT *y_dense_shape_addr = static_cast<indiceT *>(outputs[kOutputIndex0]->addr);
  indiceT *y_batch_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex1]->addr);
  indiceT *y_row_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex2]->addr);
  indiceT *y_col_indices_addr = static_cast<indiceT *>(outputs[kOutputIndex3]->addr);
  valueT *y_values_addr = static_cast<valueT *>(outputs[kOutputIndex4]->addr);
  std::vector<int64_t> y_row_pointers_shape;
  int64_t batch_pointers = x_batch_pointers_size_;
  if (rank_x_ == kRankWithBatch) {
    // Batched input: keep the batch extent, swap rows and cols.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
    int64_t batch = x_dense_shape[kDenseShape0];
    BATCH_CHECK(batch, batch_pointers, kernel_name_)
  } else {
    // Unbatched input: swap rows and cols; batch_pointers must be {0, nnz}.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
    if (batch_pointers != kTwo) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input batch pointers should equal "
                        << "2 to match the CSR form input when input has no batch.";
    }
  }
  // The batch partition of the nonzeros is unchanged by transposition.
  for (int64_t i = 0; i < batch_pointers; ++i) {
    y_batch_pointers_addr[i] = x_batch_pointers[i];
  }
  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
  int64_t num_cols = x_dense_shape[rank_x_ - 1];
  int64_t num_batch = batch_pointers - 1;
  std::vector<int64_t> y_part_row_pointers(num_cols + 1);
  std::vector<int64_t> part_row_pointers(num_rows + 1);
  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input row pointers should equal"
                      << " batch * (rows + 1) to match the CSR form input.";
  }
  for (int64_t j = 0; j < num_batch; ++j) {
    // n = number of nonzeros in this batch slice.
    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
    std::vector<valueT> part_values(n);
    std::vector<indiceT> part_col_indices(n);
    std::vector<indiceT> y_part_col_indices(n);
    std::vector<valueT> y_part_values(n);
    for (int64_t i = 0; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] = 0;
    }
    for (int64_t k = 0; k < num_rows + 1; ++k) {
      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
    }
    for (int64_t k = 0; k < n; ++k) {
      part_values[k] = x_values[x_batch_pointers[j] + k];
      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
    }
    // Counting sort: histogram column indices, prefix-sum into row pointers.
    for (int64_t i = 0; i < n; ++i) {
      y_part_row_pointers[part_col_indices[i] + 1] += 1;
    }
    for (int64_t i = 1; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
    }
    for (int64_t k = 0; k < num_cols + 1; ++k) {
      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
    }
    // Stable scatter of (row, value) pairs into their column buckets.
    std::vector<int64_t> current_col_count(num_cols);
    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = part_row_pointers[row_idx];
      const int64_t row_end = part_row_pointers[row_idx + 1];
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int64_t col_idx = part_col_indices[i];
        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
        y_part_col_indices[offset] = row_idx;
        y_part_values[offset] = part_values[i];
        current_col_count[col_idx] += 1;
      }
    }
    for (int64_t k = 0; k < n; ++k) {
      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
    }
  }
  // Conjugate-transpose: conjugate every stored value after the transpose.
  if (conjugate == true) {
    for (int64_t i = 0; i < x_value_size_; ++i) {
      y_values_addr[i] = std::conj(y_values_addr[i]);
    }
  }
  // Publish the runtime-dependent y_row_pointers shape back onto the node.
  auto node_ = node_wpt_.lock();
  NODE_CHECK_AND_OUTPUT_TYPE(node_)
  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
  std::vector<TypeId> dtypes(output_nm);
  for (int64_t i = 0; i < output_nm; i++) {
    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
  }
  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
  return true;
}
// Declares every supported dtype combination: the four index tensors share one
// integer dtype (all Int32 or all Int64), the values tensor may be any numeric
// dtype, and each output dtype mirrors the corresponding input.
std::vector<KernelAttr> SparseMatrixTransposeCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> kernel_attr_list = {
    // Int32 index tensors.
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int8, Int32, Int32, Int32, Int32, Int8),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt8, Int32, Int32, Int32, Int32, UInt8),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int16, Int32, Int32, Int32, Int32, Int16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt16, Int32, Int32, Int32, Int32, UInt16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int64, Int32, Int32, Int32, Int32, Int64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt32, Int32, Int32, Int32, Int32, UInt32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt64, Int32, Int32, Int32, Int32, UInt64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float16, Int32, Int32, Int32, Int32, Float16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float32, Int32, Int32, Int32, Int32, Float32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float64, Int32, Int32, Int32, Int32, Float64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex64, Int32, Int32, Int32, Int32, Complex64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex128, Int32, Int32, Int32, Int32, Complex128),
    // Int64 index tensors.
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int8, Int64, Int64, Int64, Int64, Int8),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt8, Int64, Int64, Int64, Int64, UInt8),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int16, Int64, Int64, Int64, Int64, Int16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt16, Int64, Int64, Int64, Int64, UInt16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int32, Int64, Int64, Int64, Int64, Int32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt32, Int64, Int64, Int64, Int64, UInt32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt64, Int64, Int64, Int64, Int64, UInt64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float16, Int64, Int64, Int64, Int64, Float16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float32, Int64, Int64, Int64, Int64, Float32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float64, Int64, Int64, Int64, Int64, Float64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex64, Int64, Int64, Int64, Int64, Complex64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex128, Int64, Int64, Int64, Int64, Complex128)};
  return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseMatrixTranspose, SparseMatrixTransposeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
#include <algorithm>
#include <complex>
#include <iostream>
#include <map>
#include <functional>
#include <numeric>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
// CPU kernel computing the transpose of a (possibly batched) CSR sparse
// matrix given as the 5-tuple (dense_shape, batch_pointers, row_pointers,
// col_indices, values). For complex dtypes the 'conjugate' attribute selects
// the conjugate transpose.
class SparseMatrixTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseMatrixTransposeCpuKernelMod() = default;
  ~SparseMatrixTransposeCpuKernelMod() override = default;

  // Caches input shapes, dtypes and the 'conjugate' attribute from the node.
  void InitKernel(const CNodePtr &kernel_node) override;

  // Dispatches on the cached index/value dtypes to a typed implementation.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename indiceT, typename valueT>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  template <typename indiceT, typename valueT>
  bool LaunchcomplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  // Shape metadata cached by InitKernel; brace-initialized so the members
  // never hold indeterminate values before InitKernel runs.
  int64_t x_batch_pointers_size_{0};
  int64_t x_value_size_{0};
  int64_t x_col_indice_size_{0};
  int64_t x_row_pointer_size_{0};
  int64_t rank_x_{0};
  bool conjugate{false};  // 'conjugate' attribute (complex dtypes only)
  TypeId indiceT_{kTypeUnknown};  // dtype of the index tensors (int32/int64)
  TypeId valueT_{kTypeUnknown};   // dtype of the values tensor
  CNodeWeakPtr node_wpt_;         // weak ref used to refresh output shapes at launch
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_

View File

@ -279,6 +279,7 @@ constexpr auto kSparseMatrixAdd = "SparseMatrixAdd";
constexpr auto kSparseAdd = "SparseAdd";
constexpr auto kSparseConcat = "SparseConcat";
constexpr auto kSparseMatrixNNZ = "SparseMatrixNNZ";
constexpr auto kSparseMatrixTranspose = "SparseMatrixTranspose";
// Sparse Grad ops
constexpr auto kSparseAddGrad = "SparseAddGrad";
@ -886,6 +887,7 @@ GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor, std::make_shared<Prim
GVAR_DEF(PrimitivePtr, kPrimSparseConcat, std::make_shared<Primitive>(kSparseConcat));
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixNNZ, std::make_shared<Primitive>(kSparseMatrixNNZ));
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared<Primitive>(kSparseMatrixTranspose));
// Sparse Grad ops
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));

View File

@ -0,0 +1,163 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include "ops/sparse_matrix_transpose.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the five output shapes of SparseMatrixTranspose.
// dense_shape, batch_pointers, col_indices and values keep their input
// shapes; row_pointers becomes cols + 1 (rank 2) or batch * (cols + 1)
// (rank 3) when dense_shape is a known constant, otherwise dynamic with
// max_length as the upper bound.
abstract::TupleShapePtr SparseMatrixTransposeInferShape(const PrimitivePtr &primitive,
                                                        const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  const int64_t kInputNoBatch = 2;
  const int64_t kInputWithBatch = 3;
  const int64_t ktwo = 2;
  std::vector<int64_t> dense_shape_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  // dense_shape is a 1-D tensor whose length equals the matrix rank.
  const int64_t rank_x = dense_shape_shape[0];
  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  const int64_t max_length = GetValue<int64_t>(max_length_ptr);
  std::vector<int64_t> batch_pointers_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  std::vector<int64_t> row_pointers_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  std::vector<int64_t> col_indices_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  std::vector<int64_t> values_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
  if (rank_x != kInputNoBatch && rank_x != kInputWithBatch) {
    // Report the rank that was actually checked, not the rank of dense_shape itself.
    MS_EXCEPTION(ValueError) << "For " << prim_name << ",the rank of input must be 2 or 3, but got " << rank_x << "!";
  }
  if (batch_pointers_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input batch pointers must be 1-D, but got "
                             << batch_pointers_shape.size() << "-D.";
  }
  if (row_pointers_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input row pointers must be 1-D, but got "
                             << row_pointers_shape.size() << "-D.";
  }
  if (col_indices_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input col indices must be 1-D, but got "
                             << col_indices_shape.size() << "-D.";
  }
  if (values_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input values must be 1-D, but got "
                             << values_shape.size() << "-D.";
  }
  // Outputs 0, 1, 3 and 4 keep the corresponding input shape.
  auto transpose_shape_shape = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(transpose_shape_shape);
  abstract::ShapePtr transpose_shape_shape_list = transpose_shape_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(transpose_shape_shape_list);
  auto transpose_batch_shape = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(transpose_batch_shape);
  abstract::ShapePtr transpose_batch_shape_list = transpose_batch_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(transpose_batch_shape_list);
  auto transpose_col_indices_shape = input_args[kInputIndex3]->BuildShape();
  MS_EXCEPTION_IF_NULL(transpose_col_indices_shape);
  abstract::ShapePtr transpose_col_indices_shape_list = transpose_col_indices_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(transpose_col_indices_shape_list);
  auto transpose_values_shape = input_args[kInputIndex4]->BuildShape();
  MS_EXCEPTION_IF_NULL(transpose_values_shape);
  abstract::ShapePtr transpose_values_shape_list = transpose_values_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(transpose_values_shape_list);
  // Row pointers default to a dynamic shape; refined when dense_shape is a
  // compile-time constant tensor.
  ShapeVector transpose_row_pointers_shape = {abstract::Shape::SHP_ANY};
  if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
      !input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex0]->BuildValue()->isa<None>()) {
    auto dense_shape = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(dense_shape);
    auto dense_shape_ptr = dense_shape->BuildValue();
    MS_EXCEPTION_IF_NULL(dense_shape_ptr);
    auto dense_shape_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue("dense_shape", dense_shape_ptr, prim_name);
    if (rank_x == kInputNoBatch) {
      transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[1] + 1;
    } else {
      transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[0] * (dense_shape_ptr_tensor[ktwo] + 1);
    }
    if (transpose_row_pointers_shape[0] > max_length) {
      MS_EXCEPTION(ValueError) << "For " << prim_name << "the shape of output row pointers must be "
                               << "less than max length: " << max_length << ", but got "
                               << transpose_row_pointers_shape[0]
                               << "! The shape of output row pointers should be reduced"
                               << " or max_length should be increased.";
    }
  }
  ShapeVector transpose_row_pointer_min_shape = {0};
  ShapeVector transpose_row_pointer_max_shape = {max_length};
  abstract::ShapePtr transpose_row_pointers_shape_list = std::make_shared<abstract::Shape>(
    transpose_row_pointers_shape, transpose_row_pointer_min_shape, transpose_row_pointer_max_shape);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
    transpose_shape_shape_list, transpose_batch_shape_list, transpose_row_pointers_shape_list,
    transpose_col_indices_shape_list, transpose_values_shape_list});
}
// Validates that the four index tensors share one integer dtype (int32/int64)
// and that values has a supported numeric dtype; the five output dtypes mirror
// the corresponding inputs.
TuplePtr SparseMatrixTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
  const std::set<TypePtr> values_valid_types = {kInt8,   kInt16,   kInt32,   kInt64,     kUInt8,
                                                kUInt16, kUInt32,  kUInt64,  kFloat16,   kFloat32,
                                                kFloat64, kComplex64, kComplex128};
  auto dense_shape_type = input_args[kInputIndex0]->BuildType();
  auto batch_type = input_args[kInputIndex1]->BuildType();
  auto row_type = input_args[kInputIndex2]->BuildType();
  auto col_type = input_args[kInputIndex3]->BuildType();
  auto value_type = input_args[kInputIndex4]->BuildType();
  // The four index-like tensors must all carry the same (valid) index dtype.
  std::map<std::string, TypePtr> index_types = {{"x_dense_shape", dense_shape_type},
                                                {"x_batch_pointers", batch_type},
                                                {"x_row_pointers", row_type},
                                                {"x_col_indices", col_type}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(index_types, index_valid_types, prim->name());
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_values", value_type, values_valid_types, prim->name());
  // Outputs reuse the already-built input types one-for-one.
  std::vector<TypePtr> out_types = {dense_shape_type, batch_type, row_type, col_type, value_type};
  return std::make_shared<Tuple>(out_types);
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseMatrixTranspose, BaseOperator);
// Top-level infer entry: checks the argument count, then combines the type
// and shape inference results into one abstract value.
AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kSparseMatrixTransposeInputsNum = 5;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSparseMatrixTransposeInputsNum, primitive->name());
  auto types = SparseMatrixTransposeInferType(primitive, input_args);
  auto shapes = SparseMatrixTransposeInferShape(primitive, input_args);
  return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SparseMatrixTranspose, prim::kPrimSparseMatrixTranspose, SparseMatrixTransposeInfer,
nullptr, true);
REGISTER_HOST_DEPENDS(kNameSparseMatrixTranspose, {0});
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
#define MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseMatrixTranspose = "SparseMatrixTranspose";
/// \brief Return the transpose of the input CSR sparse matrix (or batch of CSR matrices).
class MIND_API SparseMatrixTranspose : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseMatrixTranspose);
  /// \brief Constructor. Registers the five CSR component inputs
  /// (dense shape, batch pointers, row pointers, column indices, values)
  /// and the five corresponding transposed outputs by name.
  SparseMatrixTranspose() : BaseOperator(kNameSparseMatrixTranspose) {
    InitIOName({"x_dense_shape", "x_batch_pointers", "x_row_pointers", "x_col_indices", "x_values"},
               {"y_dense_shape", "y_batch_pointers", "y_row_pointers", "y_col_indices", "y_values"});
  }
};
/// \brief Infer the abstract value (shapes and types of the five outputs) for SparseMatrixTranspose.
abstract::AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const std::vector<abstract::AbstractBasePtr> &input_args);
/// \brief Shared-pointer alias for the SparseMatrixTranspose operator.
using PrimSparseMatrixTransposePtr = std::shared_ptr<SparseMatrixTranspose>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -4,6 +4,7 @@ from .. import functional as F
from .._grad.grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.sparse_ops import SparseTensorDenseAdd
from ..operations.sparse_ops import SparseMatrixTranspose
@bprop_getters.register(SparseTensorDenseAdd)
@ -12,3 +13,15 @@ def get_bprop_sparse_tensor_dense_add(self):
def bprop(x1_indices, x1_values, x1_shape, x2, out, dout):
return (zeros_like(x1_indices), F.gather_nd(dout, x1_indices), zeros_like(x1_shape), dout,)
return bprop
@bprop_getters.register(SparseMatrixTranspose)
def get_bprop_sparse_matrix_transpose(self):
    """Grad definition for 'SparseMatrixTranspose' operation"""
    # The gradient of a (conjugate) transpose is the (conjugate) transpose of
    # the incoming gradient, so reuse the op with the same `conjugate` setting.
    transpose_grad_op = SparseMatrixTranspose(conjugate=self.conjugate)

    def bprop(x_dense_shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values, out, dout):
        # dout carries the five CSR components of the output gradient;
        # transposing them back yields the gradients of the five inputs.
        dx = transpose_grad_op(dout[0], dout[1], dout[2], dout[3], dout[4])
        return dx[0], dx[1], dx[2], dx[3], dx[4]

    return bprop

View File

@ -272,3 +272,4 @@ from .pow import _pow_aicpu
from .depth_to_space import _depth_to_space_aicpu
from .space_to_depth import _space_to_depth_aicpu
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
from .sparse_matrix_transpose import _sparse_matrix_transpose_aicpu

View File

@ -0,0 +1,116 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseMatrixTranspose op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
def _build_sparse_matrix_transpose_op_info():
    """Build the AiCPU op info for SparseMatrixTranspose.

    The op consumes the five CSR components (dense_shape, batch_pointers,
    row_pointers, col_indices, values) and produces the same five components of
    the transposed matrix.  Every supported dtype combination pairs one index
    dtype (int32 or int64, used for the four structural tensors on both sides)
    with one value dtype, so the table is generated from those two lists
    instead of being spelled out as 26 near-identical ``dtype_format`` entries.
    """
    op_info = AiCPURegOp("SparseMatrixTranspose") \
        .fusion_type("OPAQUE") \
        .attr("conjugate", "bool") \
        .input(0, "x_dense_shape", "required") \
        .input(1, "x_batch_pointers", "required") \
        .input(2, "x_row_pointers", "required") \
        .input(3, "x_col_indices", "required") \
        .input(4, "x_values", "required") \
        .output(0, "y_dense_shape", "required") \
        .output(1, "y_batch_pointers", "required") \
        .output(2, "y_row_pointers", "required") \
        .output(3, "y_col_indices", "required") \
        .output(4, "y_values", "required")
    # Index dtypes apply to dense_shape/batch_pointers/row_pointers/col_indices.
    index_types = (DataType.I32_Default, DataType.I64_Default)
    # Value dtypes cover the full BasicType set supported by the kernel,
    # listed in the same order as the original explicit registration table.
    value_types = (DataType.I8_Default, DataType.U8_Default, DataType.I16_Default, DataType.U16_Default,
                   DataType.I32_Default, DataType.U32_Default, DataType.I64_Default, DataType.U64_Default,
                   DataType.F16_Default, DataType.F32_Default, DataType.F64_Default,
                   DataType.C64_Default, DataType.C128_Default)
    for index_type in index_types:
        for value_type in value_types:
            # Layout per entry: four index inputs, value input,
            # four index outputs, value output.
            op_info = op_info.dtype_format(index_type, index_type, index_type, index_type, value_type,
                                           index_type, index_type, index_type, index_type, value_type)
    return op_info.get_op_info()


sparse_matrix_transpose_op_info = _build_sparse_matrix_transpose_op_info()
@op_info_register(sparse_matrix_transpose_op_info)
def _sparse_matrix_transpose_aicpu():
    """SparseMatrixTranspose AiCPU register.

    The decorator performs the actual registration of the op info with the
    kernel registry; the function body is intentionally empty.
    """
    return

View File

@ -20,7 +20,7 @@
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, Primitive, prim_attr_register
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class SparseToDense(PrimitiveWithInfer):
@ -749,3 +749,91 @@ class CSRSparseMatrixToDense(Primitive):
self.init_prim_io_names(
inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'],
outputs=['y'])
class SparseMatrixTranspose(Primitive):
    r"""
    Return the transpose of a sparse matrix or a batch of sparse matrices.

    If the input sparse matrix contains a batch dimension, the output keeps the same batch dimension.
    The rank of the sparse matrix input must be equal to `2` or `3`.

    Note:
        It is assumed that all the inputs can form a legal CSR sparse matrix, otherwise this operator is not defined.

    Args:
        conjugate (bool): If True, the output sparse tensor is conjugated. Default: False.

    Inputs:
        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of the input sparse matrix under dense status.
          Support int32, int64. The shape is :math:`(2,)` or :math:`(3,)`.
        - **batch_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements number in each batch.
          Support int32, int64, takes on values: :math:`(0, nnz[0], nnz[0] + nnz[1], ..., total\_nnz)`.
          If there are `n` batch within input sparse matrix, the shape is :math:`(n+1)`.
        - **row_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements of each row.
          Support int32, int64, takes on values:
          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
          for :math:`b = 0, ..., n - 1`.
          If there are `n` batch within input sparse matrix and dense shape is :math:`(rows,cols)`,
          the shape is :math:`((rows + 1) * n)`.
          Note: num_rows{0}[0] means the non-zero elements number in the first row of the first sparse matrix.
        - **col_indices** (Tensor) - A 1-D Tensor, represents the column values for the given row and column index.
          Support int32, int64. The shape is :math:`(M)`,
          where `M` is the number of non-zero elements in all input sparse matrices.
        - **values** (Tensor) - A 1-D Tensor, represents the actual values for the given row and column index.
          Support BasicType. The shape is :math:`(M)`, where `M` is the number of non-zero elements in all
          input sparse matrices.

    Outputs:
        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of the output sparse matrix under dense status.
          Support int32, int64. The shape is the same as the input sparse matrix.
        - **batch_pointers** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's batch_pointers.
        - **row_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements of each row of the output sparse
          matrix. Support int32, int64, takes on values:
          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
          for :math:`b = 0, ..., n - 1`.
          If there are `n` batch within output sparse matrix and dense shape is :math:`(rows,cols)`,
          the shape is :math:`((rows + 1) * n)`.
          Note: num_rows{0}[0] means the non-zero elements number in the first row of the first sparse matrix.
        - **col_indices** (Tensor) - A 1-D Tensor, represents the column values for the given row and column index.
          Support int32, int64. The shape is :math:`(M)`,
          where `M` is the number of non-zero elements in all input sparse matrices.
        - **values** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's values.

    Raises:
        TypeError: If dtype of `values` doesn't meet the parameter description.
        TypeError: If the data type of `dense_shape, batch_pointers, row_pointers, col_indices` is not int32 or int64.
        ValueError: If rank of `dense_shape` is not 2 or 3.
        TypeError: If the input data does not have the correct CSR form.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as ops
        >>> dense_shape = Tensor([2,3], dtype=ms.int32)
        >>> batch_pointers = Tensor([0,1], dtype=ms.int32)
        >>> row_pointers = Tensor([0,1,1], dtype=ms.int32)
        >>> col_indices = Tensor([0], dtype=ms.int32)
        >>> values = Tensor([99], dtype=ms.float32)
        >>> sparse_matrix_transpose = ops.SparseMatrixTranspose()
        >>> output = sparse_matrix_transpose(dense_shape, batch_pointers, row_pointers, col_indices, values)
        >>> print(output[0])
        [3 2]
        >>> print(output[1])
        [0 1]
        >>> print(output[2])
        [0 1 1 1]
        >>> print(output[3])
        [0]
        >>> print(output[4])
        [99.]
    """
    @prim_attr_register
    def __init__(self, conjugate=False):
        """Initialize SparseMatrixTranspose"""
        validator.check_value_type("conjugate", conjugate, [bool], self.name)
        # NOTE(review): "max_length" appears to be an upper bound consumed by the
        # backend kernel (e.g. on element count) — confirm against the AiCPU implementation.
        self.add_prim_attr("max_length", 100000000)
        self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
                                        'x_col_indices', 'x_values'],
                                outputs=['y_dense_shape', 'y_batch_pointers', 'y_row_pointers',
                                         'y_col_indices', 'y_values'])

View File

@ -108,6 +108,7 @@ from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
from mindspore.ops.operations.sparse_ops import SparseTensorDenseMatmul
from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
from mindspore.ops.operations.other_ops import BlackmanWindow
from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
from mindspore.nn.layer import normalization
@ -4022,6 +4023,22 @@ test_case_quant_ops = [
'block': inner.Quant(80.0, 10.0, False, "Round"),
'desc_inputs': [Tensor([100.0, 200.0], mstype.float16)],
'skip': ['backward']}),
('SparseMatrixTranspose1', {
'block': SparseMatrixTranspose(conjugate=False),
'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
Tensor(np.array([0, 2]).astype(np.int32)),
Tensor(np.array([0, 2, 2]).astype(np.int32)),
Tensor(np.array([0, 2]).astype(np.int32)),
Tensor(np.array([5.3, 2.4]).astype(np.float32))],
'skip': ['backward']}),
('SparseMatrixTranspose2', {
'block': SparseMatrixTranspose(conjugate=True),
'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
Tensor(np.array([0, 2]).astype(np.int32)),
Tensor(np.array([0, 2, 2]).astype(np.int32)),
Tensor(np.array([0, 2]).astype(np.int32)),
Tensor(np.array([5.3, 2.4]).astype(np.float32))],
'skip': ['backward']}),
]
test_case_sparse_ops = [