!32175 [feat][assistant][I48O92]Add SparseMatrixTranspose
Merge pull request !32175 from 李定维/SparseMatrixTranspose
This commit is contained in:
commit
bd82adc603
|
@ -0,0 +1,410 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_matrix_transpose_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
// Generic small constant used for rank-2 offsets and the no-batch pointer count.
constexpr int64_t kTwo = 2;
// The kernel consumes 5 inputs (dense_shape, batch_pointers, row_pointers,
// col_indices, values) and produces the same 5 tensors for the transpose.
constexpr int64_t kInputsNum = 5;
constexpr int64_t kOutputsNum = 5;
// Positional indices of the CSR-form inputs.
constexpr int64_t kInputIndex0 = 0;
constexpr int64_t kInputIndex1 = 1;
constexpr int64_t kInputIndex2 = 2;
constexpr int64_t kInputIndex3 = 3;
constexpr int64_t kInputIndex4 = 4;
// Positional indices of the CSR-form outputs (mirror the inputs).
constexpr int64_t kOutputIndex0 = 0;
constexpr int64_t kOutputIndex1 = 1;
constexpr int64_t kOutputIndex2 = 2;
constexpr int64_t kOutputIndex3 = 3;
constexpr int64_t kOutputIndex4 = 4;
// Indices into the dense_shape tensor: [batch, rows, cols] when batched,
// [rows, cols] otherwise.
constexpr int64_t kDenseShape0 = 0;
constexpr int64_t kDenseShape1 = 1;
constexpr int64_t kDenseShape2 = 2;
// Expected rank of dense_shape: 2 = single matrix, 3 = batched matrices.
constexpr int64_t kRankWithOutBatch = 2;
constexpr int64_t kRankWithBatch = 3;
|
||||
// Builds a KernelAttr describing five input dtypes followed by five output
// dtypes, in the order the arguments are given.
KernelAttr AddKernel(const TypeId &ms_type1, const TypeId &ms_type2, const TypeId &ms_type3, const TypeId &ms_type4,
                     const TypeId &ms_type5, const TypeId &ms_type6, const TypeId &ms_type7, const TypeId &ms_type8,
                     const TypeId &ms_type9, const TypeId &ms_type10) {
  KernelAttr attr;
  const TypeId input_types[] = {ms_type1, ms_type2, ms_type3, ms_type4, ms_type5};
  const TypeId output_types[] = {ms_type6, ms_type7, ms_type8, ms_type9, ms_type10};
  for (const TypeId &type : input_types) {
    (void)attr.AddInputAttr(type);
  }
  for (const TypeId &type : output_types) {
    (void)attr.AddOutputAttr(type);
  }
  return attr;
}
|
||||
|
||||
// Expands ten dtype tags (e.g. Int32, Float64) into kNumberType enumerators
// and forwards them to AddKernel above.
#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)                                                       \
  AddKernel(kNumberType##t1, kNumberType##t2, kNumberType##t3, kNumberType##t4, kNumberType##t5, kNumberType##t6, \
            kNumberType##t7, kNumberType##t8, kNumberType##t9, kNumberType##t10)

// One switch case of the value-dtype dispatch in Launch(): instantiates
// LaunchKernel with the matching index/value C++ types.
#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
  case (DTYPE): {                                                                        \
    LaunchKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                   \
    break;                                                                               \
  }

// Same as above but routes to LaunchcomplexKernel, which additionally supports
// the 'conjugate' attribute via std::conj.
#define SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(DTYPE, VTYPEONE, VTYPETWO, inputs, outputs) \
  case (DTYPE): {                                                                                \
    LaunchcomplexKernel<VTYPEONE, VTYPETWO>(inputs, outputs);                                    \
    break;                                                                                       \
  }

// Fails loudly when the weak pointer to the kernel CNode has expired; the node
// is needed after compute to refresh the inferred output shapes.
#define NODE_CHECK_AND_OUTPUT_TYPE(node_)           \
  do {                                              \
    if (!node_) {                                   \
      MS_LOG(EXCEPTION) << "node_wpt_ is expired."; \
    }                                               \
  } while (0);

// Re-infers the five output shapes on node_, substituting the freshly computed
// y_row_pointers_shape for output 2 (its length depends on the runtime
// dense_shape values, not on static inference).
#define SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)                    \
  common::AnfAlgo::SetOutputInferTypeAndShape(                                            \
    dtypes,                                                                               \
    {common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex0),                          \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex1), y_row_pointers_shape,    \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex3),                          \
     common::AnfAlgo::GetOutputInferShape(node_, kOutputIndex4)},                         \
    node_.get());

// Validates that batch_pointers has batch + 1 entries, as CSR batching
// requires one boundary pointer per batch plus a terminator.
#define BATCH_CHECK(batch, batch_pointers, kernel_name_)                                                   \
  do {                                                                                                     \
    if (batch + 1 != batch_pointers) {                                                                     \
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of batch pionters shape should equals" \
                        << "dense shape[0] + 1 to match the CSR form input when input has batch.";         \
    }                                                                                                      \
  } while (0);
|
||||
} // namespace
|
||||
|
||||
void SparseMatrixTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
int64_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
|
||||
int64_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kernel_name_);
|
||||
std::vector<int64_t> input_dense_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex0);
|
||||
rank_x_ = input_dense_shape[0];
|
||||
if (rank_x_ != kRankWithOutBatch && rank_x_ != kRankWithBatch) {
|
||||
MS_LOG(EXCEPTION) << "For SparseMatrixTranspose,the rank must be 2 or 3, but got" << rank_x_ << "!";
|
||||
}
|
||||
conjugate = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "conjugate");
|
||||
std::vector<int64_t> x_batch_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex1);
|
||||
std::vector<int64_t> x_row_pointers_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex2);
|
||||
std::vector<int64_t> x_col_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex3);
|
||||
std::vector<int64_t> x_value_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex4);
|
||||
x_value_size_ = x_value_shape[0];
|
||||
x_batch_pointers_size_ = x_batch_pointers_shape[0];
|
||||
x_col_indice_size_ = x_col_indices_shape[0];
|
||||
x_row_pointer_size_ = x_row_pointers_shape[0];
|
||||
if (x_col_indice_size_ != x_value_size_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of col indice shape should equals "
|
||||
<< "values shape to match the CSR form input.";
|
||||
}
|
||||
node_wpt_ = kernel_node;
|
||||
indiceT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex0);
|
||||
valueT_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputIndex4);
|
||||
}
|
||||
|
||||
// Dispatches to the typed LaunchKernel/LaunchcomplexKernel instantiation
// based on the index dtype (int32/int64) and value dtype cached by
// InitKernel. Complex dtypes route through the complex variant so the
// 'conjugate' attribute can be honored.
bool SparseMatrixTransposeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  switch (indiceT_) {
    case kNumberTypeInt32:
      // Inner switch selects the value type; each macro expands to a
      // "case: LaunchKernel<int32_t, VT>(...); break;" block.
      switch (valueT_) {
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int32_t, int8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int32_t, uint8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int32_t, int16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int32_t, uint16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int32_t, uint32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int32_t, int64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int32_t, uint64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int32_t, float16, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int32_t, float, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int32_t, double, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int32_t, complex64, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int32_t, complex128, inputs, outputs)
        default:
          MS_LOG(EXCEPTION) << "data type of values is not included.";
          break;
      }
      break;
    case kNumberTypeInt64:
      switch (valueT_) {
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt8, int64_t, int8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt8, int64_t, uint8_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt16, int64_t, int16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt16, int64_t, uint16_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt32, int64_t, int32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt32, int64_t, uint32_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeUInt64, int64_t, uint64_t, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat16, int64_t, float16, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat32, int64_t, float, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_CASE(kNumberTypeFloat64, int64_t, double, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex64, int64_t, complex64, inputs, outputs)
        SPARSE_MATRIX_TRANSPOSE_COMPUTE_COMPLEX_CASE(kNumberTypeComplex128, int64_t, complex128, inputs, outputs)
        default:
          MS_LOG(EXCEPTION) << "data type of values is not included.";
          break;
      }
      break;
    default:
      // Index tensors (dense_shape/batch_pointers/row_pointers/col_indices)
      // must share one of the two supported integer dtypes.
      MS_LOG(EXCEPTION) << "The data type of dense_shape, batch_pointers, "
                        << "row_pointers, col_indices is not int32 or int64.";
      break;
  }
  return true;
}
|
||||
|
||||
// Transposes a (possibly batched) CSR matrix. For each batch slice it builds
// the transposed row pointers via a counting pass over the column indices,
// then scatters values/row indices into their transposed positions
// (classic CSR -> CSC conversion, reinterpreted as the CSR of the transpose).
template <typename indiceT, typename valueT>
bool SparseMatrixTransposeCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                     const std::vector<kernel::AddressPtr> &outputs) {
  // Raw CSR components of the input x and output y.
  indiceT *x_dense_shape = static_cast<indiceT *>(inputs[kInputIndex0]->addr);
  indiceT *x_batch_pointers = static_cast<indiceT *>(inputs[kInputIndex1]->addr);
  indiceT *x_row_pointers = static_cast<indiceT *>(inputs[kInputIndex2]->addr);
  indiceT *x_col_indices = static_cast<indiceT *>(inputs[kInputIndex3]->addr);
  valueT *x_values = static_cast<valueT *>(inputs[kInputIndex4]->addr);
  indiceT *y_dense_shape_addr = static_cast<indiceT *>(outputs[kOutputIndex0]->addr);
  indiceT *y_batch_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex1]->addr);
  indiceT *y_row_pointers_addr = static_cast<indiceT *>(outputs[kOutputIndex2]->addr);
  indiceT *y_col_indices_addr = static_cast<indiceT *>(outputs[kOutputIndex3]->addr);
  valueT *y_values_addr = static_cast<valueT *>(outputs[kOutputIndex4]->addr);
  std::vector<int64_t> y_row_pointers_shape;
  int64_t batch_pointers = x_batch_pointers_size_;
  if (rank_x_ == kRankWithBatch) {
    // Batched: dense shape [b, r, c] -> [b, c, r]; row pointers output has
    // b * (c + 1) entries.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
    int64_t batch = x_dense_shape[kDenseShape0];
    BATCH_CHECK(batch, batch_pointers, kernel_name_)
  } else {
    // Unbatched: dense shape [r, c] -> [c, r]; batch pointers must be {0, nnz}.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
    if (batch_pointers != kTwo) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of batch pionters shape should equals"
                        << "2 to match the CSR form input when input has no batch.";
    }
  }
  // Batch pointers are unchanged by transposition: nnz per batch is preserved.
  for (int64_t i = 0; i < batch_pointers; ++i) {
    y_batch_pointers_addr[i] = x_batch_pointers[i];
  }
  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
  int64_t num_cols = x_dense_shape[rank_x_ - 1];
  int64_t num_batch = batch_pointers - 1;
  // Scratch buffers reused across batches.
  std::vector<int64_t> y_part_row_pointers(num_cols + 1);
  std::vector<int64_t> part_row_pointers(num_rows + 1);
  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of row pionters shape should equals"
                      << " batch*(rows + 1 ) to match the CSR form input when input has no batch.";
  }
  for (int64_t j = 0; j < num_batch; ++j) {
    // n = number of nonzeros in batch j.
    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
    std::vector<valueT> part_values(n);
    std::vector<indiceT> part_col_indices(n);
    std::vector<indiceT> y_part_col_indices(n);
    std::vector<valueT> y_part_values(n);
    // Reset the per-batch histogram of nonzeros per output row.
    for (int64_t i = 0; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] = 0;
    }
    // Copy this batch's row pointers and nonzeros into local scratch.
    for (int64_t k = 0; k < num_rows + 1; ++k) {
      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
    }
    for (int64_t k = 0; k < n; ++k) {
      part_values[k] = x_values[x_batch_pointers[j] + k];
      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
    }
    // Counting pass: histogram of entries per input column (= output row),
    // shifted by one so the prefix sum below yields row pointers.
    for (int64_t i = 0; i < n; ++i) {
      y_part_row_pointers[part_col_indices[i] + 1] += 1;
    }
    // Prefix sum converts counts into CSR row pointers for the transpose.
    for (int64_t i = 1; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
    }
    for (int64_t k = 0; k < num_cols + 1; ++k) {
      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
    }
    // Scatter pass: walk the input rows in order and drop each entry into the
    // next free slot of its destination (transposed) row. Walking rows in
    // order keeps column indices sorted within each output row.
    std::vector<int64_t> current_col_count(num_cols);
    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = part_row_pointers[row_idx];
      const int64_t row_end = part_row_pointers[row_idx + 1];
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int64_t col_idx = part_col_indices[i];
        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
        y_part_col_indices[offset] = row_idx;
        y_part_values[offset] = part_values[i];
        current_col_count[col_idx] += 1;
      }
    }
    // Write this batch's transposed nonzeros back to the output buffers.
    for (int64_t k = 0; k < n; ++k) {
      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
    }
  }
  // The row-pointers output length depends on runtime dense_shape values, so
  // re-infer the output shapes on the node with the computed shape.
  auto node_ = node_wpt_.lock();
  NODE_CHECK_AND_OUTPUT_TYPE(node_)
  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
  std::vector<TypeId> dtypes(output_nm);
  for (int64_t i = 0; i < output_nm; i++) {
    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
  }
  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
  return true;
}
|
||||
|
||||
// Complex-valued variant of LaunchKernel: identical CSR transpose algorithm,
// plus an optional element-wise std::conj pass when the 'conjugate' attribute
// is set (i.e. computes the conjugate transpose).
// NOTE(review): the body duplicates LaunchKernel almost line-for-line except
// for reinterpret_cast and the conjugation pass — a shared helper would
// remove the duplication.
template <typename indiceT, typename valueT>
bool SparseMatrixTransposeCpuKernelMod::LaunchcomplexKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                            const std::vector<kernel::AddressPtr> &outputs) {
  // Raw CSR components of the input x and output y.
  indiceT *x_dense_shape = reinterpret_cast<indiceT *>(inputs[kInputIndex0]->addr);
  indiceT *x_batch_pointers = reinterpret_cast<indiceT *>(inputs[kInputIndex1]->addr);
  indiceT *x_row_pointers = reinterpret_cast<indiceT *>(inputs[kInputIndex2]->addr);
  indiceT *x_col_indices = reinterpret_cast<indiceT *>(inputs[kInputIndex3]->addr);
  valueT *x_values = reinterpret_cast<valueT *>(inputs[kInputIndex4]->addr);
  indiceT *y_dense_shape_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex0]->addr);
  indiceT *y_batch_pointers_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex1]->addr);
  indiceT *y_row_pointers_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex2]->addr);
  indiceT *y_col_indices_addr = reinterpret_cast<indiceT *>(outputs[kOutputIndex3]->addr);
  valueT *y_values_addr = reinterpret_cast<valueT *>(outputs[kOutputIndex4]->addr);
  std::vector<int64_t> y_row_pointers_shape;
  int64_t batch_pointers = x_batch_pointers_size_;
  if (rank_x_ == kRankWithBatch) {
    // Batched: dense shape [b, r, c] -> [b, c, r].
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape0];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape2];
    y_dense_shape_addr[kDenseShape2] = x_dense_shape[kDenseShape1];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape0] * (x_dense_shape[kDenseShape2] + 1));
    int64_t batch = x_dense_shape[kDenseShape0];
    BATCH_CHECK(batch, batch_pointers, kernel_name_)
  } else {
    // Unbatched: dense shape [r, c] -> [c, r]; batch pointers must be {0, nnz}.
    y_dense_shape_addr[kDenseShape0] = x_dense_shape[kDenseShape1];
    y_dense_shape_addr[kDenseShape1] = x_dense_shape[kDenseShape0];
    y_row_pointers_shape.push_back(x_dense_shape[kDenseShape1] + 1);
    if (batch_pointers != kTwo) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of batch pionters shape should equals"
                        << "2 to match the CSR form input when input has no batch.";
    }
  }
  // Batch pointers are unchanged by transposition.
  for (int64_t i = 0; i < batch_pointers; ++i) {
    y_batch_pointers_addr[i] = x_batch_pointers[i];
  }
  int64_t num_rows = x_dense_shape[rank_x_ - kTwo];
  int64_t num_cols = x_dense_shape[rank_x_ - 1];
  int64_t num_batch = batch_pointers - 1;
  std::vector<int64_t> y_part_row_pointers(num_cols + 1);
  std::vector<int64_t> part_row_pointers(num_rows + 1);
  if (x_row_pointer_size_ != num_batch * (num_rows + 1)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input of row pionters shape should equals"
                      << " batch*(rows + 1 ) to match the CSR form input when input has no batch.";
  }
  for (int64_t j = 0; j < num_batch; ++j) {
    // n = number of nonzeros in batch j.
    int64_t n = x_batch_pointers[j + 1] - x_batch_pointers[j];
    std::vector<valueT> part_values(n);
    std::vector<indiceT> part_col_indices(n);
    std::vector<indiceT> y_part_col_indices(n);
    std::vector<valueT> y_part_values(n);
    // Reset the per-batch histogram of nonzeros per output row.
    for (int64_t i = 0; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] = 0;
    }
    // Copy this batch's row pointers and nonzeros into local scratch.
    for (int64_t k = 0; k < num_rows + 1; ++k) {
      part_row_pointers[k] = x_row_pointers[(num_rows + 1) * j + k];
    }
    for (int64_t k = 0; k < n; ++k) {
      part_values[k] = x_values[x_batch_pointers[j] + k];
      part_col_indices[k] = x_col_indices[x_batch_pointers[j] + k];
    }
    // Counting pass followed by prefix sum yields the transposed row pointers.
    for (int64_t i = 0; i < n; ++i) {
      y_part_row_pointers[part_col_indices[i] + 1] += 1;
    }
    for (int64_t i = 1; i < num_cols + 1; ++i) {
      y_part_row_pointers[i] += y_part_row_pointers[i - 1];
    }
    for (int64_t k = 0; k < num_cols + 1; ++k) {
      y_row_pointers_addr[(num_cols + 1) * j + k] = y_part_row_pointers[k];
    }
    // Scatter pass: place each entry in the next free slot of its transposed
    // row; walking rows in order keeps output column indices sorted.
    std::vector<int64_t> current_col_count(num_cols);
    for (int64_t row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = part_row_pointers[row_idx];
      const int64_t row_end = part_row_pointers[row_idx + 1];
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int64_t col_idx = part_col_indices[i];
        const int64_t offset = y_part_row_pointers[col_idx] + current_col_count[col_idx];
        y_part_col_indices[offset] = row_idx;
        y_part_values[offset] = part_values[i];
        current_col_count[col_idx] += 1;
      }
    }
    for (int64_t k = 0; k < n; ++k) {
      y_values_addr[x_batch_pointers[j] + k] = y_part_values[k];
      y_col_indices_addr[x_batch_pointers[j] + k] = y_part_col_indices[k];
    }
  }
  // Conjugate transpose: negate the imaginary part of every output value.
  if (conjugate == true) {
    for (int64_t i = 0; i < x_value_size_; ++i) {
      y_values_addr[i] = std::conj(y_values_addr[i]);
    }
  }
  // Re-infer output shapes with the runtime-dependent row-pointers length.
  auto node_ = node_wpt_.lock();
  NODE_CHECK_AND_OUTPUT_TYPE(node_)
  int64_t output_nm = common::AnfAlgo::GetOutputTensorNum(node_);
  std::vector<TypeId> dtypes(output_nm);
  for (int64_t i = 0; i < output_nm; i++) {
    dtypes[i] = AnfAlgo::GetOutputDeviceDataType(node_, i);
  }
  SET_OUTPUT_SHAPE_AND_TYPE(node_, dtypes, y_row_pointers_shape)
  return true;
}
|
||||
// Enumerates every supported (index dtype, value dtype) combination.
// Layout per entry: 4 index-typed inputs + values input, then the mirrored
// 5 outputs. Index dtype is uniformly int32 or int64; value dtype spans all
// numeric types including complex.
std::vector<KernelAttr> SparseMatrixTransposeCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> kernel_attr_list = {
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int8, Int32, Int32, Int32, Int32, Int8),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt8, Int32, Int32, Int32, Int32, UInt8),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int16, Int32, Int32, Int32, Int32, Int16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt16, Int32, Int32, Int32, Int32, UInt16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32, Int32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Int64, Int32, Int32, Int32, Int32, Int64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt32, Int32, Int32, Int32, Int32, UInt32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, UInt64, Int32, Int32, Int32, Int32, UInt64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float16, Int32, Int32, Int32, Int32, Float16),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float32, Int32, Int32, Int32, Int32, Float32),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Float64, Int32, Int32, Int32, Int32, Float64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex64, Int32, Int32, Int32, Int32, Complex64),
    ADD_KERNEL(Int32, Int32, Int32, Int32, Complex128, Int32, Int32, Int32, Int32, Complex128),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int8, Int64, Int64, Int64, Int64, Int8),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt8, Int64, Int64, Int64, Int64, UInt8),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int16, Int64, Int64, Int64, Int64, Int16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt16, Int64, Int64, Int64, Int64, UInt16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int32, Int64, Int64, Int64, Int64, Int32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt32, Int64, Int64, Int64, Int64, UInt32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, UInt64, Int64, Int64, Int64, Int64, UInt64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float16, Int64, Int64, Int64, Int64, Float16),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float32, Int64, Int64, Int64, Int64, Float32),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Float64, Int64, Int64, Int64, Int64, Float64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex64, Int64, Int64, Int64, Int64, Complex64),
    ADD_KERNEL(Int64, Int64, Int64, Int64, Complex128, Int64, Int64, Int64, Int64, Complex128)};

  return kernel_attr_list;
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseMatrixTranspose, SparseMatrixTransposeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
|
||||
class SparseMatrixTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
SparseMatrixTransposeCpuKernelMod() = default;
|
||||
~SparseMatrixTransposeCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename indiceT, typename valueT>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
template <typename indiceT, typename valueT>
|
||||
bool LaunchcomplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
int64_t x_batch_pointers_size_;
|
||||
int64_t x_value_size_;
|
||||
int64_t x_col_indice_size_;
|
||||
int64_t x_row_pointer_size_;
|
||||
int64_t rank_x_;
|
||||
bool conjugate;
|
||||
TypeId indiceT_{kTypeUnknown};
|
||||
TypeId valueT_{kTypeUnknown};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_MATRIX_TRANSPOSE_CPU_KERNEL_H_
|
|
@ -279,6 +279,7 @@ constexpr auto kSparseMatrixAdd = "SparseMatrixAdd";
|
|||
constexpr auto kSparseAdd = "SparseAdd";
|
||||
constexpr auto kSparseConcat = "SparseConcat";
|
||||
constexpr auto kSparseMatrixNNZ = "SparseMatrixNNZ";
|
||||
constexpr auto kSparseMatrixTranspose = "SparseMatrixTranspose";
|
||||
|
||||
// Sparse Grad ops
|
||||
constexpr auto kSparseAddGrad = "SparseAddGrad";
|
||||
|
@ -886,6 +887,7 @@ GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToSparseTensor, std::make_shared<Prim
|
|||
GVAR_DEF(PrimitivePtr, kPrimSparseConcat, std::make_shared<Primitive>(kSparseConcat));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixNNZ, std::make_shared<Primitive>(kSparseMatrixNNZ));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared<Primitive>(kSparseMatrixTranspose));
|
||||
|
||||
// Sparse Grad ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "ops/sparse_matrix_transpose.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr SparseMatrixTransposeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t kInputNoBatch = 2;
|
||||
const int64_t kInputWithBatch = 3;
|
||||
const int64_t ktwo = 2;
|
||||
std::vector<int64_t> dense_shape_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
const int64_t rank_x = dense_shape_shape[0];
|
||||
auto max_length_ptr = primitive->GetAttr("max_length");
|
||||
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
||||
const int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
||||
std::vector<int64_t> batch_pointers_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
std::vector<int64_t> row_pointers_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
std::vector<int64_t> col_indices_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
std::vector<int64_t> values_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
if (rank_x != kInputNoBatch && rank_x != kInputWithBatch) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ",the rank of input must be 2 or 3, but got "
|
||||
<< dense_shape_shape.size() << "!";
|
||||
}
|
||||
if (batch_pointers_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input batch pointers must be 1-D, but got "
|
||||
<< batch_pointers_shape.size() << "-D.";
|
||||
}
|
||||
if (row_pointers_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input row pointers must be 1-D, but got "
|
||||
<< row_pointers_shape.size() << "-D.";
|
||||
}
|
||||
if (col_indices_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input col indices must be 1-D, but got "
|
||||
<< col_indices_shape.size() << "-D.";
|
||||
}
|
||||
if (values_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ",the shape of input col indices must be 1-D, but got "
|
||||
<< col_indices_shape.size() << "-D.";
|
||||
}
|
||||
auto transpose_shape_shape = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(transpose_shape_shape);
|
||||
abstract::ShapePtr transpose_shape_shape_list = transpose_shape_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(transpose_shape_shape_list);
|
||||
auto transpose_batch_shape = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(transpose_batch_shape);
|
||||
abstract::ShapePtr transpose_batch_shape_list = transpose_batch_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(transpose_batch_shape_list);
|
||||
auto transpose_col_indices_shape = input_args[kInputIndex3]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(transpose_col_indices_shape);
|
||||
abstract::ShapePtr transpose_col_indices_shape_list = transpose_col_indices_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(transpose_col_indices_shape_list);
|
||||
auto transpose_values_shape = input_args[kInputIndex4]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(transpose_values_shape);
|
||||
abstract::ShapePtr transpose_values_shape_list = transpose_values_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(transpose_values_shape_list);
|
||||
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<None>()) {
|
||||
auto dense_shape = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape);
|
||||
auto dense_shape_ptr = dense_shape->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_ptr);
|
||||
auto dense_shape_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue("dense_shape", dense_shape_ptr, prim_name);
|
||||
ShapeVector transpose_row_pointers_shape = {0};
|
||||
if (rank_x == kInputNoBatch) {
|
||||
transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[1] + 1;
|
||||
} else {
|
||||
transpose_row_pointers_shape[0] = dense_shape_ptr_tensor[0] * (dense_shape_ptr_tensor[ktwo] + 1);
|
||||
}
|
||||
if (transpose_row_pointers_shape[0] > max_length) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << "the shape of output row pointers must be "
|
||||
<< "less than max length: " << max_length << ", but got "
|
||||
<< transpose_row_pointers_shape[0]
|
||||
<< "! The shape of output row pointers should be reduced"
|
||||
<< " or max_length should be increased.";
|
||||
}
|
||||
ShapeVector transpose_row_pointer_min_shape = {0};
|
||||
ShapeVector transpose_row_pointer_max_shape = {max_length};
|
||||
abstract::ShapePtr transpose_row_pointers_shape_list = std::make_shared<abstract::Shape>(
|
||||
transpose_row_pointers_shape, transpose_row_pointer_min_shape, transpose_row_pointer_max_shape);
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
|
||||
transpose_shape_shape_list, transpose_batch_shape_list, transpose_row_pointers_shape_list,
|
||||
transpose_col_indices_shape_list, transpose_values_shape_list});
|
||||
} else {
|
||||
ShapeVector transpose_row_pointers_shape = {abstract::Shape::SHP_ANY};
|
||||
ShapeVector transpose_row_pointer_min_shape = {0};
|
||||
ShapeVector transpose_row_pointer_max_shape = {max_length};
|
||||
abstract::ShapePtr transpose_row_pointers_shape_list = std::make_shared<abstract::Shape>(
|
||||
transpose_row_pointers_shape, transpose_row_pointer_min_shape, transpose_row_pointer_max_shape);
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
|
||||
transpose_shape_shape_list, transpose_batch_shape_list, transpose_row_pointers_shape_list,
|
||||
transpose_col_indices_shape_list, transpose_values_shape_list});
|
||||
}
|
||||
}
|
||||
|
||||
// Infers the output dtypes of SparseMatrixTranspose.
// Checks that the four index-like CSR components (dense shape, batch pointers,
// row pointers, col indices) share one type drawn from {int32, int64} and that
// the values tensor has a supported numeric/complex type, then forwards the
// five input types unchanged as the tuple of output types (transpose does not
// change any component dtype).
TuplePtr SparseMatrixTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
  const std::set<TypePtr> values_valid_types = {kInt8,  kInt16,   kInt32,   kInt64,   kUInt8,     kUInt16,    kUInt32,
                                                kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
  auto dense_shape_type = input_args[kInputIndex0]->BuildType();
  auto batch_type = input_args[kInputIndex1]->BuildType();
  auto row_type = input_args[kInputIndex2]->BuildType();
  auto col_type = input_args[kInputIndex3]->BuildType();
  auto value_type = input_args[kInputIndex4]->BuildType();
  // All index-like inputs must agree on a single integer type.
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x_dense_shape", dense_shape_type);
  (void)types.emplace("x_batch_pointers", batch_type);
  (void)types.emplace("x_row_pointers", row_type);
  (void)types.emplace("x_col_indices", col_type);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, index_valid_types, prim->name());
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_values", value_type, values_valid_types, prim->name());
  // Reuse the types fetched above instead of calling BuildType() again on
  // every input (the original re-queried all five abstracts redundantly).
  std::vector<TypePtr> types_list = {dense_shape_type, batch_type, row_type, col_type, value_type};
  return std::make_shared<Tuple>(types_list);
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseMatrixTranspose, BaseOperator);
|
||||
// Top-level abstract inference entry for SparseMatrixTranspose: validates the
// argument count, then merges the inferred dtypes and shapes into a single
// abstract value for the five-element output tuple.
AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kExpectedInputNum = 5;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kExpectedInputNum, primitive->name());
  // Type inference runs first (as in the original) so type errors surface
  // before shape errors on doubly-invalid input.
  auto inferred_types = SparseMatrixTransposeInferType(primitive, input_args);
  auto inferred_shapes = SparseMatrixTransposeInferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shapes, inferred_types);
}
|
||||
// Register the eval (infer) implementation for the primitive. The host-depends
// registration marks input index 0 (x_dense_shape) as a value that must be
// available on host, since shape inference reads its concrete contents.
REGISTER_PRIMITIVE_EVAL_IMPL(SparseMatrixTranspose, prim::kPrimSparseMatrixTranspose, SparseMatrixTransposeInfer,
                             nullptr, true);
REGISTER_HOST_DEPENDS(kNameSparseMatrixTranspose, {0});
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseMatrixTranspose = "SparseMatrixTranspose";
|
||||
/// \brief Return the transpose of the input CSR sparse matrix (or batch of
/// CSR matrices). The operator consumes and produces the five CSR component
/// tensors: dense shape, batch pointers, row pointers, column indices, values.
class MIND_API SparseMatrixTranspose : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseMatrixTranspose);
  /// \brief Constructor. Registers the five CSR component input names and the
  /// corresponding five transposed output names.
  SparseMatrixTranspose() : BaseOperator(kNameSparseMatrixTranspose) {
    InitIOName({"x_dense_shape", "x_batch_pointers", "x_row_pointers", "x_col_indices", "x_values"},
               {"y_dense_shape", "y_batch_pointers", "y_row_pointers", "y_col_indices", "y_values"});
  }
};
|
||||
|
||||
abstract::AbstractBasePtr SparseMatrixTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseMatrixTransposePtr = std::shared_ptr<SparseMatrixTranspose>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_MATRIX_TRANSPOSE_H_
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -4,6 +4,7 @@ from .. import functional as F
|
|||
from .._grad.grad_base import bprop_getters
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.sparse_ops import SparseTensorDenseAdd
|
||||
from ..operations.sparse_ops import SparseMatrixTranspose
|
||||
|
||||
|
||||
@bprop_getters.register(SparseTensorDenseAdd)
|
||||
|
@ -12,3 +13,15 @@ def get_bprop_sparse_tensor_dense_add(self):
|
|||
def bprop(x1_indices, x1_values, x1_shape, x2, out, dout):
|
||||
return (zeros_like(x1_indices), F.gather_nd(dout, x1_indices), zeros_like(x1_shape), dout,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(SparseMatrixTranspose)
def get_bprop_sparse_matrix_transpose(self):
    """Grad definition for 'SparseMatrixTranspose' operation"""
    # Transposing the incoming gradient CSR matrix (with the same conjugate
    # setting as the forward op) yields the gradient for each CSR component.
    transpose_op = SparseMatrixTranspose(conjugate=self.conjugate)

    def bprop(x_dense_shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values, out, dout):
        # dout is the 5-tuple of CSR components of the output gradient.
        grad_shape, grad_batch, grad_rows, grad_cols, grad_values = transpose_op(
            dout[0], dout[1], dout[2], dout[3], dout[4])
        return grad_shape, grad_batch, grad_rows, grad_cols, grad_values

    return bprop
|
||||
|
|
|
@ -272,3 +272,4 @@ from .pow import _pow_aicpu
|
|||
from .depth_to_space import _depth_to_space_aicpu
|
||||
from .space_to_depth import _space_to_depth_aicpu
|
||||
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
|
||||
from .sparse_matrix_transpose import _sparse_matrix_transpose_aicpu
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseMatrixTranspose op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# Value dtypes supported for the `values` tensor. The four index-like tensors
# (dense_shape, batch_pointers, row_pointers, col_indices) must all be int32
# or all be int64; every combination of index dtype x value dtype below is
# registered, in the same order as the original hand-written table.
_SPARSE_MATRIX_TRANSPOSE_VALUE_DTYPES = (
    DataType.I8_Default, DataType.U8_Default, DataType.I16_Default, DataType.U16_Default,
    DataType.I32_Default, DataType.U32_Default, DataType.I64_Default, DataType.U64_Default,
    DataType.F16_Default, DataType.F32_Default, DataType.F64_Default,
    DataType.C64_Default, DataType.C128_Default)

_sparse_matrix_transpose_reg = AiCPURegOp("SparseMatrixTranspose") \
    .fusion_type("OPAQUE") \
    .attr("conjugate", "bool") \
    .input(0, "x_dense_shape", "required") \
    .input(1, "x_batch_pointers", "required") \
    .input(2, "x_row_pointers", "required") \
    .input(3, "x_col_indices", "required") \
    .input(4, "x_values", "required") \
    .output(0, "y_dense_shape", "required") \
    .output(1, "y_batch_pointers", "required") \
    .output(2, "y_row_pointers", "required") \
    .output(3, "y_col_indices", "required") \
    .output(4, "y_values", "required")
# Generate the 2 x 13 dtype table instead of 26 hand-written chains; inputs
# 0-3 and outputs 0-3 share the index dtype, input/output 4 carry the value
# dtype. `dtype_format` returns the reg-op object, so chaining in a loop is
# equivalent to the literal chain.
for _index_dtype in (DataType.I32_Default, DataType.I64_Default):
    for _value_dtype in _SPARSE_MATRIX_TRANSPOSE_VALUE_DTYPES:
        _sparse_matrix_transpose_reg = _sparse_matrix_transpose_reg.dtype_format(
            _index_dtype, _index_dtype, _index_dtype, _index_dtype, _value_dtype,
            _index_dtype, _index_dtype, _index_dtype, _index_dtype, _value_dtype)
sparse_matrix_transpose_op_info = _sparse_matrix_transpose_reg.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sparse_matrix_transpose_op_info)
def _sparse_matrix_transpose_aicpu():
    """SparseMatrixTranspose AiCPU registration stub; the decorator attaches the op info."""
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, Primitive, prim_attr_register
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
|
||||
|
||||
class SparseToDense(PrimitiveWithInfer):
|
||||
|
@ -749,3 +749,91 @@ class CSRSparseMatrixToDense(Primitive):
|
|||
self.init_prim_io_names(
|
||||
inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'],
|
||||
outputs=['y'])
|
||||
|
||||
|
||||
class SparseMatrixTranspose(Primitive):
    r"""
    Return the transpose of a sparse matrix or a batch of sparse matrices.
    If the sparse matrix input contains a batch dimension, the output keeps the same batch dimension.
    The rank of the sparse matrix input must be equal to `2` or `3`.

    Note:
        It is assumed that all the inputs can form a legal CSR sparse matrix, otherwise this operator is not defined.

    Args:
        conjugate (bool): If True, the output sparse tensor is conjugated. Default: False.

    Inputs:
        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of input sparse matrix under dense status.
          Support int32, int64. The shape is :math:`(2,)` or :math:`(3,)`.
        - **batch_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements number in each batch.
          Support int32, int64, takes on values: :math:`(0, nnz[0], nnz[0] + nnz[1], ..., total\_nnz)`.
          If there are `n` batch within input sparse matrix, the shape is :math:`(n+1)`.
        - **row_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements of each row.
          Support int32, int64, takes on values:
          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
          for :math:`b = 0, ..., n - 1`.
          If there are `n` batch within input sparse matrix and dense shape is :math:`(rows,cols)`,
          the shape is :math:`((rows + 1) * n)`.
          Note: num_rows{0}[0] means the non-zero elements number in the first row of first sparse matrix.
        - **col_indices** (Tensor) - A 1-D Tensor, represents the column values for the given row and column index.
          Support int32, int64. The shape is :math:`(M)`,
          where `M` is the number of non-zero elements in all input sparse matrix.
        - **values** (Tensor) - A 1-D Tensor, represents the actual values for the given row and column index.
          Support BasicType. The shape is :math:`(M)`, where `M` is the number of non-zero elements in all
          input sparse matrix.

    Outputs:
        - **dense_shape** (Tensor) - A 1-D Tensor, represents the shape of output sparse matrix under dense status.
          Support int32, int64. The shape is the same as the input sparse matrix.
        - **batch_pointers** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's batch_pointers.
        - **row_pointers** (Tensor) - A 1-D Tensor, represents the non-zero elements of each row of output sparse
          matrix. Support int32, int64, takes on values:
          :math:`(0, num\_rows\{b\}[0], num\_rows\{b\}[0] + num\_rows\{b\}[1], ..., nnz[b])`,
          for :math:`b = 0, ..., n - 1`.
          If there are `n` batch within output sparse matrix and dense shape is :math:`(rows,cols)`,
          the shape is :math:`((rows + 1) * n)`.
          Note: num_rows{0}[0] means the non-zero elements number in the first row of first sparse matrix.
        - **col_indices** (Tensor) - A 1-D Tensor, represents the column values for the given row and column index.
          Support int32, int64. The shape is :math:`(M)`,
          where `M` is the number of non-zero elements in all input sparse matrix.
        - **values** (Tensor) - A 1-D Tensor, which is the same as the input sparse matrix's values.

    Raises:
        TypeError: If dtype of `values` doesn't meet the parameter description.
        TypeError: The data type of `dense_shape, batch_pointers, row_pointers, col_indices` is not int32 or int64.
        ValueError: If the dense rank described by `dense_shape` is not 2 or 3.
        TypeError: The input data should have the correct CSR form.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as ops
        >>> dense_shape = Tensor([2,3], dtype=ms.int32)
        >>> batch_pointers = Tensor([0,1], dtype=ms.int32)
        >>> row_pointers = Tensor([0,1,1], dtype=ms.int32)
        >>> col_indices = Tensor([0], dtype=ms.int32)
        >>> values = Tensor([99], dtype=ms.float32)
        >>> sparse_matrix_transpose = ops.SparseMatrixTranspose()
        >>> output = sparse_matrix_transpose(dense_shape, batch_pointers, row_pointers, col_indices, values)
        >>> print(output[0])
        [3 2]
        >>> print(output[1])
        [0 1]
        >>> print(output[2])
        [0 1 1 1]
        >>> print(output[3])
        [0]
        >>> print(output[4])
        [99.]
    """
    @prim_attr_register
    def __init__(self, conjugate=False):
        """Initialize SparseMatrixTranspose"""
        validator.check_value_type("conjugate", conjugate, [bool], self.name)
        # Upper bound on the inferred row-pointers length; shape inference
        # rejects outputs whose row-pointer tensor would exceed this size.
        self.add_prim_attr("max_length", 100000000)
        self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
                                        'x_col_indices', 'x_values'],
                                outputs=['y_dense_shape', 'y_batch_pointers', 'y_row_pointers',
                                         'y_col_indices', 'y_values'])
|
||||
|
|
|
@ -108,6 +108,7 @@ from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
|
|||
from mindspore.ops.operations.sparse_ops import SparseTensorDenseMatmul
|
||||
from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
|
||||
from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
|
||||
from mindspore.ops.operations.other_ops import BlackmanWindow
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
|
||||
from mindspore.nn.layer import normalization
|
||||
|
@ -4022,6 +4023,22 @@ test_case_quant_ops = [
|
|||
'block': inner.Quant(80.0, 10.0, False, "Round"),
|
||||
'desc_inputs': [Tensor([100.0, 200.0], mstype.float16)],
|
||||
'skip': ['backward']}),
|
||||
('SparseMatrixTranspose1', {
|
||||
'block': SparseMatrixTranspose(conjugate=False),
|
||||
'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2, 2]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2]).astype(np.int32)),
|
||||
Tensor(np.array([5.3, 2.4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('SparseMatrixTranspose2', {
|
||||
'block': SparseMatrixTranspose(conjugate=True),
|
||||
'desc_inputs': [Tensor(np.array([2, 4]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2, 2]).astype(np.int32)),
|
||||
Tensor(np.array([0, 2]).astype(np.int32)),
|
||||
Tensor(np.array([5.3, 2.4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
test_case_sparse_ops = [
|
||||
|
|
Loading…
Reference in New Issue