diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.cc
new file mode 100644
index 00000000000..3dfdc84359a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.cc
@@ -0,0 +1,410 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr int64_t kTwo = 2;
+constexpr int64_t kInputsNum = 5;
+constexpr int64_t kOutputsNum = 5;
+constexpr int64_t kInputIndex0 = 0;
+constexpr int64_t kInputIndex1 = 1;
+constexpr int64_t kInputIndex2 = 2;
+constexpr int64_t kInputIndex3 = 3;
+constexpr int64_t kInputIndex4 = 4;
+constexpr int64_t kOutputIndex0 = 0;
+constexpr int64_t kOutputIndex1 = 1;
+constexpr int64_t kOutputIndex2 = 2;
+constexpr int64_t kOutputIndex3 = 3;
+constexpr int64_t kOutputIndex4 = 4;
+constexpr int64_t kDenseShape0 = 0;
+constexpr int64_t kDenseShape1 = 1;
+constexpr int64_t kDenseShape2 = 2;
+constexpr int64_t kRankWithoutBatch = 2;
+constexpr int64_t kRankWithBatch = 3;
+KernelAttr AddKernel(const TypeId &ms_type1, const TypeId &ms_type2, const TypeId &ms_type3, const TypeId &ms_type4,
+                     const TypeId &ms_type5, const TypeId &ms_type6, const TypeId &ms_type7, const TypeId &ms_type8,
+                     const TypeId &ms_type9, const TypeId &ms_type10) {
+  auto kernel = KernelAttr()
+                  .AddInputAttr(ms_type1)
+                  .AddInputAttr(ms_type2)
+                  .AddInputAttr(ms_type3)
+                  .AddInputAttr(ms_type4)
+                  .AddInputAttr(ms_type5)
+                  .AddOutputAttr(ms_type6)
+                  .AddOutputAttr(ms_type7)
+                  .AddOutputAttr(ms_type8)
+                  .AddOutputAttr(ms_type9)
+                  .AddOutputAttr(ms_type10);
+  return kernel;
+}
+
+#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)                                                       \
+  AddKernel(kNumberType##t1, kNumberType##t2, kNumberType##t3, kNumberType##t4, kNumberType##t5, kNumberType##t6, \
+            kNumberType##t7, kNumberType##t8, kNumberType##t9, kNumberType##t10)
+
+#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
+  case (DTYPE): {                                                                        \
+    LaunchKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                   \
+    break;                                                                               \
+  }
+
+#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
+  case (DTYPE): {                                                                                \
+    LaunchcomplexKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                    \
+    break;                                                                                       \
+  }
+
+#define NODE_CHECK_AND_OUTPUT_TYPE(node_)           \
+  do {                                              \
+    if (!node_) {                                   \
+      MS_LOG(EXCEPTION) << "node_wpt_ is expired."; \
+    }                                               \
+  } while (0);
+
+#define SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)                 \
+  common::AnfAlgo::SetOutputInferTypeAndShape(                                         \
+    dtypes,                                                                            \
+    {common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex0),                       \
+     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex1), y_row_pointers_shape, \
+     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex3),                       \
+     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex4)},                      \
+    node_.get());
+
+#define BATCH_CHECK(batch, batch_pointers, kernel_name_)                                              \
+  do {                                                                                                \
+    if (batch + 1 != batch_pointers) {                                                                \
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of batch pointers should equal " \
+                        << "dense_shape[0] + 1 to match the CSR form input when the input has a "    \
+                        << "batch dimension.";                                                        \
+    }                                                                                                 \
+  } while (0);
+}  // namespace
+
+void SparseMatrixTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  int64_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
+  CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
+  int64_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
+  CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kernel_name_);
+  std::vector<int64_t> input_dense_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex0);
+  rank_x_ = input_dense_shape[0];
+  if (rank_x_ != kRankWithoutBatch && rank_x_ != kRankWithBatch) {
+    MS_LOG(EXCEPTION) << "For SparseMatrixTranspose, the rank must be 2 or 3, but got " << rank_x_ << "!";
+  }
+  conjugate = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "conjugate");
+  std::vector<int64_t> x_batch_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex1);
+  std::vector<int64_t> x_row_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex2);
+  std::vector<int64_t> x_col_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex3);
+  std::vector<int64_t> x_value_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex4);
+  x_value_size_ = x_value_shape[0];
+  x_batch_pointers_size_ = x_batch_pointers_shape[0];
+  x_col_indice_size_ = x_col_indices_shape[0];
+  x_row_pointer_size_ = x_row_pointers_shape[0];
+  if (x_col_indice_size_ != x_value_size_) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of col_indices should equal "
+                      << "the shape of values to match the CSR form input.";
+  }
+  node_wpt_ = kernel_node;
+  indiceT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex0);
+  valueT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex4);
+}
+
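+// Launch dispatches on the index dtype (int32/int64) of the CSR structure tensors and on the
+// value dtype; complex values go through LaunchcomplexKernel so the optional conjugation applies.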
+bool SparseMatrixTransposeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                               const std::vector<kernel::AddressPtr> &,
+                                               const std::vector<kernel::AddressPtr> &outputs) {
+  switch (indiceT_) {
+    case kNumberTypeInt32:
+      switch (valueT_) {
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int32_t, int8_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int32_t, uint8_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int32_t, int16_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int32_t, uint16_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int32_t, uint32_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int32_t, int64_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int32_t, uint64_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int32_t, float16, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int32_t, float, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int32_t, double, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int32_t, complex64, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int32_t, complex128, inputs, outputs)
+        default:
+          MS_LOG(EXCEPTION) << "The data type of values is not supported.";
+          break;
+      }
+      break;
+    case kNumberTypeInt64:
+      switch (valueT_) {
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int64_t, int8_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int64_t, uint8_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int64_t, int16_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int64_t, uint16_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int64_t, int32_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int64_t, uint32_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int64_t, uint64_t, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int64_t, float16, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int64_t, float, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int64_t, double, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int64_t, complex64, inputs, outputs)
+        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int64_t, complex128, inputs, outputs)
+        default:
+          MS_LOG(EXCEPTION) << "The data type of values is not supported.";
+          break;
+      }
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "The data type of dense_shape, batch_pointers, "
+                        << "row_pointers, col_indices is not int32 or int64.";
+      break;
+  }
+  return true;
+}
+
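+// Transposes one CSR matrix per batch with a counting sort over the column indices:
+// (1) histogram the column indices into the transposed row pointers,
+// (2) turn the histogram into offsets with a prefix sum, and
+// (3) scatter every (row, col, value) triple into its transposed position.
+// Each batch costs O(nnz + rows + cols) time.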
+template <typename indiceT, typename valueT>
+bool SparseMatrixTransposeCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                                     const std::vector<kernel::AddressPtr> &outputs) {
+  indiceT *x_dense_shape = static_cast<indiceT *>(inputs[kInputIndex0]->addr);
+  indiceT *x_batch_pointers = static_cast<indiceT *>(inputs[kInputIndex1]->addr);
+  indiceT *x_row_pointers = static_cast<indiceT *>(inputs[kInputIndex2]->addr);
+  indiceT *x_col_indices = static_cast<indiceT *>(inputs[kInputIndex3]->addr);
+  valueT *x_values = static_cast<valueT *>(inputs[kInputIndex4]->addr);
+  indiceT *y_dense_shape_addr = static_cast<indiceT *>(outputs[kOutputIndex0]->addr);
+  indiceT *y_batch_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex1]->addr);
+  indiceT *y_row_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex2]->addr);
+  indiceT *y_col_indices_addr = static_cast<indiceT *>(outputs[kOutputIndex3]->addr);
+  valueT *y_values_addr = static_cast<valueT *>(outputs[kOutputIndex4]->addr);
+  std::vector<int64_t> y_row_pointers_shape;
+  int64_t batch_pointers = x_batch_pointers_size_;
+  if (rank_x_ == kRankWithBatch) {
+    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
+    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
+    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
+    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
+    int64_t batch = x_dense_shape[kDenseShape0];
+    BATCH_CHECK(batch, batch_pointers, kernel_name_)
+  } else {
+    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
+    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
+    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
+    if (batch_pointers != kTwo) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of batch pointers should equal "
+                        << "2 to match the CSR form input when the input has no batch dimension.";
+    }
+  }
+  for (int64_t i = 0; i < batch_pointers; ++i) {
+    y_batch_pointers_addr[i] = x_batch_pointers[i];
+  }
+  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
+  int64_t num_cols = x_dense_shape[rank_x_ - 1];
+  int64_t num_batch = batch_pointers - 1;
+  std::vector<indiceT> y_part_row_pointers(num_cols + 1);
+  std::vector<indiceT> part_row_pointers(num_rows + 1);
+  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of row pointers should equal "
+                      << "batch * (rows + 1) to match the CSR form input.";
+  }
+  for (int64_t j = 0; j < num_batch; ++j) {
+    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
+    std::vector<valueT> part_values(n);
+    std::vector<indiceT> part_col_indices(n);
+    std::vector<indiceT> y_part_col_indices(n);
+    std::vector<valueT> y_part_values(n);
+    for (int64_t i = 0; i < num_cols + 1; ++i) {
+      y_part_row_pointers[i] = 0;
+    }
+    for (int64_t k = 0; k < num_rows + 1; ++k) {
+      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
+    }
+    for (int64_t k = 0; k < n; ++k) {
+      part_values[k] = x_values[x_batch_pointers[j] + k];
+      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
+    }
+    for (int64_t i = 0; i < n; ++i) {
+      y_part_row_pointers[part_col_indices[i] + 1] += 1;
+    }
+    for (int64_t i = 1; i < num_cols + 1; ++i) {
+      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
+    }
+    for (int64_t k = 0; k < num_cols + 1; ++k) {
+      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
+    }
+    std::vector<int64_t> current_col_count(num_cols);
+    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+      const int64_t row_begin = part_row_pointers[row_idx];
+      const int64_t row_end = part_row_pointers[row_idx + 1];
+      for (int64_t i = row_begin; i < row_end; ++i) {
+        const int64_t col_idx = part_col_indices[i];
+        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
+        y_part_col_indices[offset] = row_idx;
+        y_part_values[offset] = part_values[i];
+        current_col_count[col_idx] += 1;
+      }
+    }
+    for (int64_t k = 0; k < n; ++k) {
+      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
+      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
+    }
+  }
+  auto node_ = node_wpt_.lock();
+  NODE_CHECK_AND_OUTPUT_TYPE(node_)
+  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
+  std::vector<TypeId> dtypes(output_nm);
+  for (int64_t i = 0; i < output_nm; i++) {
+    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
+  }
+  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
+  return true;
+}
+
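+// Identical to LaunchKernel except that, for complex values, the transposed values are
+// conjugated afterwards when the 'conjugate' attribute is set (a conjugate transpose).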
+template <typename indiceT, typename valueT>
+bool SparseMatrixTransposeCpuKernelMod::LaunchcomplexKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                                            const std::vector<kernel::AddressPtr> &outputs) {
+  indiceT *x_dense_shape = reinterpret_cast<indiceT *>(inputs[kInputIndex0]->addr);
+  indiceT *x_batch_pointers = reinterpret_cast<indiceT *>(inputs[kInputIndex1]->addr);
+  indiceT *x_row_pointers = reinterpret_cast<indiceT *>(inputs[kInputIndex2]->addr);
+  indiceT *x_col_indices = reinterpret_cast<indiceT *>(inputs[kInputIndex3]->addr);
+  valueT *x_values = reinterpret_cast<valueT *>(inputs[kInputIndex4]->addr);
+  indiceT *y_dense_shape_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex0]->addr);
+  indiceT *y_batch_pointers_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex1]->addr);
+  indiceT *y_row_pointers_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex2]->addr);
+  indiceT *y_col_indices_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex3]->addr);
+  valueT *y_values_addr = reinterpret_cast<valueT *>(outputs[kOutputIndex4]->addr);
+  std::vector<int64_t> y_row_pointers_shape;
+  int64_t batch_pointers = x_batch_pointers_size_;
+  if (rank_x_ == kRankWithBatch) {
+    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
+    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
+    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
+    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
+    int64_t batch = x_dense_shape[kDenseShape0];
+    BATCH_CHECK(batch, batch_pointers, kernel_name_)
+  } else {
+    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
+    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
+    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
+    if (batch_pointers != kTwo) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of batch pointers should equal "
+                        << "2 to match the CSR form input when the input has no batch dimension.";
+    }
+  }
+  for (int64_t i = 0; i < batch_pointers; ++i) {
+    y_batch_pointers_addr[i] = x_batch_pointers[i];
+  }
+  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
+  int64_t num_cols = x_dense_shape[rank_x_ - 1];
+  int64_t num_batch = batch_pointers - 1;
+  std::vector<indiceT> y_part_row_pointers(num_cols + 1);
+  std::vector<indiceT> part_row_pointers(num_rows + 1);
+  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of row pointers should equal "
+                      << "batch * (rows + 1) to match the CSR form input.";
+  }
+  for (int64_t j = 0; j < num_batch; ++j) {
+    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
+    std::vector<valueT> part_values(n);
+    std::vector<indiceT> part_col_indices(n);
+    std::vector<indiceT> y_part_col_indices(n);
+    std::vector<valueT> y_part_values(n);
+    for (int64_t i = 0; i < num_cols + 1; ++i) {
+      y_part_row_pointers[i] = 0;
+    }
+    for (int64_t k = 0; k < num_rows + 1; ++k) {
+      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
+    }
+    for (int64_t k = 0; k < n; ++k) {
+      part_values[k] = x_values[x_batch_pointers[j] + k];
+      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
+    }
+    for (int64_t i = 0; i < n; ++i) {
+      y_part_row_pointers[part_col_indices[i] + 1] += 1;
+    }
+    for (int64_t i = 1; i < num_cols + 1; ++i) {
+      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
+    }
+    for (int64_t k = 0; k < num_cols + 1; ++k) {
+      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
+    }
+    std::vector<int64_t> current_col_count(num_cols);
+    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+      const int64_t row_begin = part_row_pointers[row_idx];
+      const int64_t row_end = part_row_pointers[row_idx + 1];
+      for (int64_t i = row_begin; i < row_end; ++i) {
+        const int64_t col_idx = part_col_indices[i];
+        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
+        y_part_col_indices[offset] = row_idx;
+        y_part_values[offset] = part_values[i];
+        current_col_count[col_idx] += 1;
+      }
+    }
+    for (int64_t k = 0; k < n; ++k) {
+      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
+      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
+    }
+  }
+  if (conjugate) {
+    for (int64_t i = 0; i < x_value_size_; ++i) {
+      y_values_addr[i] = std::conj(y_values_addr[i]);
+    }
+  }
+  auto node_ = node_wpt_.lock();
+  NODE_CHECK_AND_OUTPUT_TYPE(node_)
+  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
+  std::vector<TypeId> dtypes(output_nm);
+  for (int64_t i = 0; i < output_nm; i++) {
+    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
+  }
+  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
+  return true;
+}
+
+std::vector<KernelAttr> SparseMatrixTransposeCpuKernelMod::GetOpSupport() {
+  static std::vector<KernelAttr> kernel_attr_list = {
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Int8, Int32, Int32, Int32, Int32, Int8),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt8, Int32, Int32, Int32, Int32, UInt8),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Int16, Int32, Int32, Int32, Int32, Int16),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt16, Int32, Int32, Int32, Int32, UInt16),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Int64, Int32, Int32, Int32, Int32, Int64),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt32, Int32, Int32, Int32, Int32, UInt32),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt64, Int32, Int32, Int32, Int32, UInt64),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Float16, Int32, Int32, Int32, Int32, Float16),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Float32, Int32, Int32, Int32, Int32, Float32),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Float64, Int32, Int32, Int32, Int32, Float64),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex64, Int32, Int32, Int32, Int32, Complex64),
+    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex128, Int32, Int32, Int32, Int32, Complex128),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Int8, Int64, Int64, Int64, Int64, Int8),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt8, Int64, Int64, Int64, Int64, UInt8),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Int16, Int64, Int64, Int64, Int64, Int16),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt16, Int64, Int64, Int64, Int64, UInt16),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Int32, Int64, Int64, Int64, Int64, Int32),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt32, Int64, Int64, Int64, Int64, UInt32),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt64, Int64, Int64, Int64, Int64, UInt64),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Float16, Int64, Int64, Int64, Int64, Float16),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Float32, Int64, Int64, Int64, Int64, Float32),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Float64, Int64, Int64, Int64, Int64, Float64),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex64, Int64, Int64, Int64, Int64, Complex64),
+    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex128, Int64, Int64, Int64, Int64, Complex128)};
+
+  return kernel_attr_list;
+}
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseMatrixTranspose, SparseMatrixTransposeCpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h
new file mode 100644
index 00000000000..627dcd6d8b1
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
+
+#include <algorithm>
+#include <complex>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+using complex64 = std::complex<float>;
+using complex128 = std::complex<double>;
+
+class SparseMatrixTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  SparseMatrixTransposeCpuKernelMod() = default;
+  ~SparseMatrixTransposeCpuKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename indiceT, typename valueT>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  template <typename indiceT, typename valueT>
+  bool LaunchcomplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  int64_t x_batch_pointers_size_;
+  int64_t x_value_size_;
+  int64_t x_col_indice_size_;
+  int64_t x_row_pointer_size_;
+  int64_t rank_x_;
+  bool conjugate;
+  TypeId indiceT_{kTypeUnknown};
+  TypeId valueT_{kTypeUnknown};
+  CNodeWeakPtr node_wpt_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index 30e235b2823..43238ea7c8c 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -279,6 +279,7 @@ constexpr auto kSparseMatrixAdd = "SparseMatrixAdd";
 constexpr auto kSparseAdd = "SparseAdd";
 constexpr auto kSparseConcat = "SparseConcat";
 constexpr auto kSparseMatrixNNZ = "SparseMatrixNNZ";
+constexpr auto kSparseMatrixTranspose = "SparseMatrixTranspose";
 
 // Sparse Grad ops
 constexpr auto kSparseAddGrad = "SparseAddGrad";
@@ -886,6 +887,7 @@ GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor,
 GVAR_DEF(PrimitivePtr, kPrimSparseConcat, std::make_shared<Primitive>(kSparseConcat));
 GVAR_DEF(PrimitivePtr, kPrimSparseMatrixNNZ, std::make_shared<Primitive>(kSparseMatrixNNZ));
 GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
+GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared<Primitive>(kSparseMatrixTranspose));
 
 // Sparse Grad ops
 GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
diff --git a/mindspore/core/ops/sparse_matrix_transpose.cc b/mindspore/core/ops/sparse_matrix_transpose.cc
new file mode 100644
index 00000000000..5aab2c62525
--- /dev/null
+++ b/mindspore/core/ops/sparse_matrix_transpose.cc
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <map>
+#include <set>
+#include <string>
+
+#include "ops/sparse_matrix_transpose.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
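+// Shape inference: dense_shape, batch_pointers, col_indices and values keep their input shapes;
+// only the row pointers change, to (num_cols + 1) per batch, which requires the value of
+// dense_shape at infer time (hence the host-depends registration on input 0 below).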
+abstract::TupleShapePtr SparseMatrixTransposeInferShape(const PrimitivePtr &primitive,
+                                                        const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  const int64_t kInputNoBatch = 2;
+  const int64_t kInputWithBatch = 3;
+  const int64_t ktwo = 2;
+  std::vector<int64_t> dense_shape_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  const int64_t rank_x = dense_shape_shape[0];
+  auto max_length_ptr = primitive->GetAttr("max_length");
+  MS_EXCEPTION_IF_NULL(max_length_ptr);
+  const int64_t max_length = GetValue<int64_t>(max_length_ptr);
+  std::vector<int64_t> batch_pointers_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  std::vector<int64_t> row_pointers_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  std::vector<int64_t> col_indices_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
+  std::vector<int64_t> values_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
+  if (rank_x != kInputNoBatch && rank_x != kInputWithBatch) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the rank of input must be 2 or 3, but got " << rank_x
+                             << "!";
+  }
+  if (batch_pointers_shape.size() != 1) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input batch pointers must be 1-D, but got "
+                             << batch_pointers_shape.size() << "-D.";
+  }
+  if (row_pointers_shape.size() != 1) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input row pointers must be 1-D, but got "
+                             << row_pointers_shape.size() << "-D.";
+  }
+  if (col_indices_shape.size() != 1) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input col indices must be 1-D, but got "
+                             << col_indices_shape.size() << "-D.";
+  }
+  if (values_shape.size() != 1) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input values must be 1-D, but got "
+                             << values_shape.size() << "-D.";
+  }
+  auto transpose_shape_shape = input_args[kInputIndex0]->BuildShape();
+  MS_EXCEPTION_IF_NULL(transpose_shape_shape);
+  abstract::ShapePtr transpose_shape_shape_list = transpose_shape_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(transpose_shape_shape_list);
+  auto transpose_batch_shape = input_args[kInputIndex1]->BuildShape();
+  MS_EXCEPTION_IF_NULL(transpose_batch_shape);
+  abstract::ShapePtr transpose_batch_shape_list = transpose_batch_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(transpose_batch_shape_list);
+  auto transpose_col_indices_shape = input_args[kInputIndex3]->BuildShape();
+  MS_EXCEPTION_IF_NULL(transpose_col_indices_shape);
+  abstract::ShapePtr transpose_col_indices_shape_list = transpose_col_indices_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(transpose_col_indices_shape_list);
+  auto transpose_values_shape = input_args[kInputIndex4]->BuildShape();
+  MS_EXCEPTION_IF_NULL(transpose_values_shape);
+  abstract::ShapePtr transpose_values_shape_list = transpose_values_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(transpose_values_shape_list);
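+  // When dense_shape is a known constant, the transposed row-pointers length can be computed
+  // exactly; otherwise it stays dynamic, bounded above by the "max_length" attribute.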
+  if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
+      !input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
+      !input_args[kInputIndex0]->BuildValue()->isa<None>()) {
+    auto dense_shape = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
+    MS_EXCEPTION_IF_NULL(dense_shape);
+    auto dense_shape_ptr = dense_shape->BuildValue();
+    MS_EXCEPTION_IF_NULL(dense_shape_ptr);
+    auto dense_shape_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue("dense_shape", dense_shape_ptr, prim_name);
+    ShapeVector transpose_row_pointers_shape = {0};
+    if (rank_x == kInputNoBatch) {
+      transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[1] + 1;
+    } else {
+      transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[0] * (dense_shape_ptr_tensor[ktwo] + 1);
+    }
+    if (transpose_row_pointers_shape[0] > max_length) {
+      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the shape of output row pointers must be "
+                               << "less than max length: " << max_length << ", but got "
+                               << transpose_row_pointers_shape[0]
+                               << "! The shape of output row pointers should be reduced"
+                               << " or max_length should be increased.";
+    }
+    ShapeVector transpose_row_pointer_min_shape = {0};
+    ShapeVector transpose_row_pointer_max_shape = {max_length};
+    abstract::ShapePtr transpose_row_pointers_shape_list = std::make_shared<abstract::Shape>(
+      transpose_row_pointers_shape, transpose_row_pointer_min_shape, transpose_row_pointer_max_shape);
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
+      transpose_shape_shape_list, transpose_batch_shape_list, transpose_row_pointers_shape_list,
+      transpose_col_indices_shape_list, transpose_values_shape_list});
+  } else {
+    ShapeVector transpose_row_pointers_shape = {abstract::Shape::SHP_ANY};
+    ShapeVector transpose_row_pointer_min_shape = {0};
+    ShapeVector transpose_row_pointer_max_shape = {max_length};
+    abstract::ShapePtr transpose_row_pointers_shape_list = std::make_shared<abstract::Shape>(
+      transpose_row_pointers_shape, transpose_row_pointer_min_shape, transpose_row_pointer_max_shape);
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
+      transpose_shape_shape_list, transpose_batch_shape_list, transpose_row_pointers_shape_list,
+      transpose_col_indices_shape_list, transpose_values_shape_list});
+  }
+}
+
+TuplePtr SparseMatrixTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
+  const std::set<TypePtr> values_valid_types = {kInt8,   kInt16,   kInt32,   kInt64,     kUInt8,      kUInt16,
+                                                kUInt32, kUInt64,  kFloat16, kFloat32,   kFloat64,    kComplex64,
+                                                kComplex128};
+  auto dense_shape_type = input_args[kInputIndex0]->BuildType();
+  auto batch_type = input_args[kInputIndex1]->BuildType();
+  auto row_type = input_args[kInputIndex2]->BuildType();
+  auto col_type = input_args[kInputIndex3]->BuildType();
+  auto value_type = input_args[kInputIndex4]->BuildType();
+  std::map<std::string, TypePtr> types;
+  (void)types.emplace("x_dense_shape", dense_shape_type);
+  (void)types.emplace("x_batch_pointers", batch_type);
+  (void)types.emplace("x_row_pointers", row_type);
+  (void)types.emplace("x_col_indices", col_type);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, index_valid_types, prim->name());
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_values", value_type, values_valid_types, prim->name());
+  std::vector<TypePtr> types_list = {input_args[kInputIndex0]->BuildType(), input_args[kInputIndex1]->BuildType(),
+                                     input_args[kInputIndex2]->BuildType(), input_args[kInputIndex3]->BuildType(),
+                                     input_args[kInputIndex4]->BuildType()};
+  return std::make_shared<Tuple>(types_list);
+}
+}  // namespace
+
+MIND_API_OPERATOR_IMPL(SparseMatrixTranspose, BaseOperator);
+AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 5;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto infer_type = SparseMatrixTransposeInferType(primitive, input_args);
+  auto infer_shape = SparseMatrixTransposeInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(SparseMatrixTranspose, prim::kPrimSparseMatrixTranspose, SparseMatrixTransposeInfer,
+                             nullptr, true);
+REGISTER_HOST_DEPENDS(kNameSparseMatrixTranspose, {0});
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/sparse_matrix_transpose.h b/mindspore/core/ops/sparse_matrix_transpose.h
new file mode 100644
index 00000000000..a71eacc05e2
--- /dev/null
+++ b/mindspore/core/ops/sparse_matrix_transpose.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
+#define MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
+#include <memory>
+#include <vector>
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSparseMatrixTranspose = "SparseMatrixTranspose";
+/// \brief Returns the transpose of the input CSR tensor.
+class MIND_API SparseMatrixTranspose : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(SparseMatrixTranspose);
+  /// \brief Constructor.
+  SparseMatrixTranspose() : BaseOperator(kNameSparseMatrixTranspose) {
+    InitIOName({"x_dense_shape", "x_batch_pointers", "x_row_pointers", "x_col_indices", "x_values"},
+               {"y_dense_shape", "y_batch_pointers", "y_row_pointers", "y_col_indices", "y_values"});
+  }
+};
+
+abstract::AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &,
+                                                     const PrimitivePtr &primitive,
+                                                     const std::vector<AbstractBasePtr> &input_args);
+using PrimSparseMatrixTransposePtr = std::shared_ptr<SparseMatrixTranspose>;
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/__init__.py b/mindspore/python/mindspore/ops/_grad_experimental/__init__.py
index 5977263dce0..9c8211a8168 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/__init__.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse.py
index 48e89cdc7c9..ea8296f55c6 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse.py
@@ -4,6 +4,7 @@ from .. import functional as F
 from .._grad.grad_base import bprop_getters
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.sparse_ops import SparseTensorDenseAdd
+from ..operations.sparse_ops import SparseMatrixTranspose
 
 
 @bprop_getters.register(SparseTensorDenseAdd)
@@ -12,3 +13,15 @@ def get_bprop_sparse_tensor_dense_add(self):
     def bprop(x1_indices, x1_values, x1_shape, x2, out, dout):
         return (zeros_like(x1_indices), F.gather_nd(dout, x1_indices), zeros_like(x1_shape), dout,)
     return bprop
+
+
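+# The gradient of a (conjugate) transpose is the (conjugate) transpose of the incoming gradient,
+# so the bprop re-applies SparseMatrixTranspose, with the same 'conjugate' attribute, to the five
+# CSR components of dout.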
+@bprop_getters.register(SparseMatrixTranspose)
+def get_bprop_sparse_matrix_transpose(self):
+    """Grad definition for 'SparseMatrixTranspose' operation"""
+    sparse_transpose = SparseMatrixTranspose(conjugate=self.conjugate)
+
+    def bprop(x_dense_shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values, out, dout):
+        dx = sparse_transpose(dout[0], dout[1], dout[2], dout[3], dout[4])
+        dx_all = (dx[0], dx[1], dx[2], dx[3], dx[4])
+        return dx_all
+    return bprop
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index d83860e4e83..5024017e473 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -272,3 +272,4 @@ from .pow import _pow_aicpu
 from .depth_to_space import _depth_to_space_aicpu
 from .space_to_depth import _space_to_depth_aicpu
 from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
+from .sparse_matrix_transpose import _sparse_matrix_transpose_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py
new file mode 100644
index 00000000000..b849c482b08
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py
@@ -0,0 +1,116 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SparseMatrixTranspose op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
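+# AiCPU registration: the four CSR structure tensors use int32 or int64 indices, each combined
+# with the thirteen supported value dtypes, giving the 26 dtype_format entries below.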
+sparse_matrix_transpose_op_info = AiCPURegOp("SparseMatrixTranspose") \
+    .fusion_type("OPAQUE") \
+    .attr("conjugate", "bool") \
+    .input(0, "x_dense_shape", "required") \
+    .input(1, "x_batch_pointers", "required") \
+    .input(2, "x_row_pointers", "required") \
+    .input(3, "x_col_indices", "required") \
+    .input(4, "x_values", "required") \
+    .output(0, "y_dense_shape", "required") \
+    .output(1, "y_batch_pointers", "required") \
+    .output(2, "y_row_pointers", "required") \
+    .output(3, "y_col_indices", "required") \
+    .output(4, "y_values", "required") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.F64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.C64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.C64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.C128_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
+                  DataType.I32_Default, DataType.C128_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.U32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.C64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.C128_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
+                  DataType.I64_Default, DataType.C128_Default) \
+    .get_op_info()
+
+
+@op_info_register(sparse_matrix_transpose_op_info)
+def _sparse_matrix_transpose_aicpu():
+    """SparseMatrixTranspose AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py
index 350194a3826..91771ab1838 100644
--- a/mindspore/python/mindspore/ops/operations/sparse_ops.py
+++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py
@@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
 
 from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, Primitive, prim_attr_register
+from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
 
 
 class SparseToDense(PrimitiveWithInfer):
@@ -749,3 +749,91 @@ class CSRSparseMatrixToDense(Primitive):
         self.init_prim_io_names(
             inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'],
             outputs=['y'])
+
+
+class SparseMatrixTranspose(Primitive):
+    r"""
+    Returns the transpose of a sparse matrix or a batch of sparse matrices.
+    If the input sparse matrix has a batch dimension, the output keeps the same batch dimension.
+    The rank of the input sparse matrix must be `2` or `3`.
+
+    Note:
+        It is assumed that all the inputs can form a legal CSR sparse matrix, otherwise this operator is not defined.
+
+    Args:
+        conjugate (bool): If True, the output sparse tensor is conjugated. Default: False.
+
+    Inputs:
+        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of the input sparse matrix under dense
+          status. Support int32, int64. The shape is :math:`(2,)` or :math:`(3,)`.
+        - **batch_pointers** (Tensor) - A 1-D Tensor, represents the number of non-zero elements in each batch.
+          Support int32, int64, takes on values: :math:`(0, nnz[0], nnz[0] + nnz[1], ..., total\_nnz)`.
+          If there are `n` batches within the input sparse matrix, the shape is :math:`(n+1)`.
+        - **row_pointers** (Tensor) - A 1-D Tensor, represents the cumulative sum of non-zero elements in each row.
+          Support int32, int64, takes on values:
+          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
+          for :math:`b = 0, ..., n - 1`.
+          If there are `n` batches within the input sparse matrix and the dense shape is :math:`(rows,cols)`,
+          the shape is :math:`((rows + 1) * n)`.
+          Note: num_rows{0}[0] means the number of non-zero elements in the first row of the first sparse matrix.
+        - **col_indices** (Tensor) - A 1-D Tensor, represents the column index of each non-zero element.
+          Support int32, int64. The shape is :math:`(M)`,
+          where `M` is the number of non-zero elements in all input sparse matrices.
+        - **values** (Tensor) - A 1-D Tensor, represents the actual values for the given row and column index.
+          Support BasicType. The shape is :math:`(M)`, where `M` is the number of non-zero elements in all
+          input sparse matrices.
+
+    Outputs:
+        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of the output sparse matrix under dense
+          status. Support int32, int64. The shape is the same as the input sparse matrix.
+        - **batch_pointers** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's batch_pointers.
+        - **row_pointers** (Tensor) - A 1-D Tensor, represents the cumulative sum of non-zero elements in each row
+          of the output sparse matrix. Support int32, int64, takes on values:
+          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
+          for :math:`b = 0, ..., n - 1`.
+          If there are `n` batches within the output sparse matrix and the dense shape is :math:`(rows,cols)`,
+          the shape is :math:`((rows + 1) * n)`.
+          Note: num_rows{0}[0] means the number of non-zero elements in the first row of the first sparse matrix.
+        - **col_indices** (Tensor) - A 1-D Tensor, represents the column index of each non-zero element in the
+          output sparse matrix. Support int32, int64. The shape is :math:`(M)`,
+          where `M` is the number of non-zero elements in all input sparse matrices.
+        - **values** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's values.
+
+    Raises:
+        TypeError: If the dtype of `values` doesn't meet the parameter description.
+        TypeError: If the data type of `dense_shape`, `batch_pointers`, `row_pointers` or `col_indices` is not
+            int32 or int64.
+        ValueError: If the rank of `dense_shape` is not 2 or 3.
+        TypeError: If the input data does not have the correct CSR form.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as ops
+        >>> dense_shape = Tensor([2,3], dtype=ms.int32)
+        >>> batch_pointers = Tensor([0,1], dtype=ms.int32)
+        >>> row_pointers = Tensor([0,1,1], dtype=ms.int32)
+        >>> col_indices = Tensor([0], dtype=ms.int32)
+        >>> values = Tensor([99], dtype=ms.float32)
+        >>> sparse_matrix_transpose = ops.SparseMatrixTranspose()
+        >>> output = sparse_matrix_transpose(dense_shape, batch_pointers, row_pointers, col_indices, values)
+        >>> print(output[0])
+        [3 2]
+        >>> print(output[1])
+        [0 1]
+        >>> print(output[2])
+        [0 1 1 1]
+        >>> print(output[3])
+        [0]
+        >>> print(output[4])
+        [99.]
+    """
+
+    @prim_attr_register
+    def __init__(self, conjugate=False):
+        """Initialize SparseMatrixTranspose"""
+        validator.check_value_type("conjugate", conjugate, [bool], self.name)
+        self.add_prim_attr("max_length", 100000000)
+        self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
+                                        'x_col_indices', 'x_values'],
+                                outputs=['y_dense_shape', 'y_batch_pointers', 'y_row_pointers',
+                                         'y_col_indices', 'y_values'])
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index aeb45a38028..49e7d52934a 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -108,6 +108,7 @@ from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
 from mindspore.ops.operations.sparse_ops import SparseTensorDenseMatmul
 from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
 from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
+from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
 from mindspore.ops.operations.other_ops import BlackmanWindow
 from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
 from mindspore.nn.layer import normalization
@@ -4022,6 +4023,22 @@ test_case_quant_ops = [
         'block': inner.Quant(80.0, 10.0, False, "Round"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float16)],
         'skip': ['backward']}),
+    ('SparseMatrixTranspose1', {
+        'block': SparseMatrixTranspose(conjugate=False),
+        'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
+                        Tensor(np.array([0, 2]).astype(np.int32)),
+                        Tensor(np.array([0, 2, 2]).astype(np.int32)),
+                        Tensor(np.array([0, 2]).astype(np.int32)),
+                        Tensor(np.array([5.3, 2.4]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('SparseMatrixTranspose2', {
+        'block': SparseMatrixTranspose(conjugate=True),
+        'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
+                        Tensor(np.array([0, 2]).astype(np.int32)),
+                        Tensor(np.array([0, 2, 2]).astype(np.int32)),
+                        Tensor(np.array([0, 2]).astype(np.int32)),
+                        Tensor(np.array([5.3, 2.4]).astype(np.float32))],
+        'skip': ['backward']}),
 ]
 
 test_case_sparse_ops = [