SparseSlice and Grad

This commit is contained in:
YR0717 2022-07-27 15:07:11 +08:00 committed by zhudongyao1
parent e94c4d627e
commit 9f42d10b56
6 changed files with 557 additions and 27 deletions

View File

@@ -321,6 +321,8 @@ constexpr auto kSparseMatrixOrderingAMD = "SparseMatrixOrderingAMD";
// Sparse Grad ops
constexpr auto kSparseAddGrad = "SparseAddGrad";
constexpr auto kSparseTensorDenseAdd = "SparseTensorDenseAdd";
constexpr auto kSparseSlice = "SparseSlice";
constexpr auto kSparseSliceGrad = "SparseSliceGrad";
// Meta Function Graph
constexpr auto kJ = "J";
@@ -971,6 +973,8 @@ GVAR_DEF(PrimitivePtr, kPrimSparseMatrixOrderingAMD, std::make_shared<Primitive>
// Sparse Grad ops
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseAdd, std::make_shared<Primitive>(kSparseTensorDenseAdd));
GVAR_DEF(PrimitivePtr, kPrimSparseSlice, std::make_shared<Primitive>(kSparseSlice));
GVAR_DEF(PrimitivePtr, kPrimSparseSliceGrad, std::make_shared<Primitive>(kSparseSliceGrad));
// TensorList
GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));

View File

@@ -0,0 +1,127 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/sparse_slice_grad.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <map>
#include <vector>
#include "abstract/param_validator.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
enum DimNum : size_t {
dim0Num = 0,
dim1Num,
dim2Num,
};
void CheckInputTensor(const std::vector<AbstractBasePtr> &input_args) {
auto backprop_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto new_indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
if (indices_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, indices should be a 2-D tensor"
<< ", while input_indices dim num is " << indices_shape.size() << ".";
}
if (indices_shape[1] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, indices shape should be (n, 2)"
<< ", while input_indices shape dim1 is " << indices_shape[1] << ".";
}
if (backprop_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, backprop_val_grad should be a 1-D tensor"
<< ", while input_backprop_val_grad dim num is " << backprop_shape.size() << ".";
}
if (start_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, start should have 2 elements"
<< ", while its dim0 size is " << start_shape[0] << ".";
}
if (new_indices_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, new_indices should be a 2-D tensor"
<< ", while input_new_indices dim num is " << new_indices_shape.size() << ".";
}
if (new_indices_shape[1] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSliceGrad, new_indices shape should be (n, 2)"
<< ", while input_new_indices shape dim1 is " << new_indices_shape[1] << ".";
}
}
bool IsDynamic(const ShapeVector &shape) {
return std::find(shape.begin(), shape.end(), -1) != shape.end();
}
abstract::ShapePtr SparseSliceGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto new_indices_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex3);
MS_EXCEPTION_IF_NULL(new_indices_shape_ptr);
auto new_indices_shape = new_indices_shape_ptr->shape();
auto backprop_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
MS_EXCEPTION_IF_NULL(backprop_shape_ptr);
auto backprop_shape = backprop_shape_ptr->shape();
if (!IsDynamic(backprop_shape) && !IsDynamic(new_indices_shape)) {
CheckInputTensor(input_args);
} else {
backprop_shape = {abstract::Shape::SHP_ANY};
new_indices_shape = {abstract::Shape::SHP_ANY, 2};
}
auto indices_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto indices_shape = indices_shape_map[kShape];
// The gradient w.r.t. values has one entry per input non-zero, i.e. indices_shape[0].
std::vector<int64_t> output_values_shape = {indices_shape[0]};
return std::make_shared<abstract::Shape>(output_values_shape);
}
TypePtr SparseSliceGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckTensorTypeValid("backprop_val_grad", input_args[kInputIndex0]->BuildType(),
{kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
prim_name);
std::map<std::string, TypePtr> in_args = {{"indices", input_args[kInputIndex1]->BuildType()},
{"start", input_args[kInputIndex2]->BuildType()},
{"new_indices", input_args[kInputIndex3]->BuildType()}};
(void)CheckAndConvertUtils::CheckTensorTypeSame(in_args, {kInt64, kInt32}, prim_name);
auto output_type = input_args[kInputIndex0]->BuildType();
return output_type;
}
} // namespace
AbstractBasePtr SparseSliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 4;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = SparseSliceGradInferType(primitive, input_args);
auto infer_shape = SparseSliceGradInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSliceGrad, prim::kPrimSparseSliceGrad, SparseSliceGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
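For reference, a minimal NumPy sketch (not part of this commit) of the gradient this op propagates, assuming TensorFlow-style SparseSliceGrad semantics: each entry of backprop_val_grad is the gradient of one sliced value and is routed back to the matching entry of the original values, with zeros for entries the slice dropped. Note that new_indices are relative to start, matching the forward op's output.

import numpy as np

def sparse_slice_grad_reference(backprop_val_grad, indices, start, new_indices):
    # One gradient slot per original value; entries dropped by the slice
    # keep a zero gradient.
    val_grad = np.zeros(indices.shape[0], dtype=backprop_val_grad.dtype)
    for g, idx in zip(backprop_val_grad, new_indices + start):
        # Shifting new_indices back by `start` recovers the position of the
        # sliced entry inside the original indices.
        pos = np.where((indices == idx).all(axis=1))[0][0]
        val_grad[pos] = g
    return val_grad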

View File

@@ -0,0 +1,48 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SLICE_GRAD_H_
#define MINDSPORE_CORE_OPS_SPARSE_SLICE_GRAD_H_
#include <vector>
#include <set>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSliceGrad = "SparseSliceGrad";
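/// \brief Computes the gradient of SparseSlice with respect to the input values.
/// The output y_grad has one entry per input non-zero.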
class MIND_API SparseSliceGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseSliceGrad);
SparseSliceGrad() : BaseOperator(kNameSparseSliceGrad) {
InitIOName(
{"backprop_val_grad", "indices", "start", "new_indices"},
{"y_grad"});
}
void Init() {}
};
AbstractBasePtr SparseSliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SLICE_GRAD_H_

View File

@@ -0,0 +1,225 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <climits>
#include "ops/sparse_slice.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
enum DimNum : size_t {
dim0Num = 0,
dim1Num,
dim2Num,
};
void CheckInputTensor(const std::vector<AbstractBasePtr> &input_args) {
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto values_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
if (indices_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice, indices should be a 2-D tensor"
<< ", while input_indices dim num is " << indices_shape.size() << ".";
}
if (indices_shape[1] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice, indices shape should be (n, 2)"
<< ", while input_indices shape dim1 is " << indices_shape[1] << ".";
}
if (values_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice, values should be a 1-D tensor"
<< ", while input_values dim num is " << values_shape.size() << ".";
}
if (indices_shape[0] != values_shape[0]) {
MS_EXCEPTION(ValueError) << "For SparseSlice"
<< ", dim0 size of `indices` and dim0 size of `values` should be the same"
<< ", while indices_shape dim0 size is " << indices_shape[0]
<< ", values_shape dim0 size is " << values_shape[0] << ".";
}
if (shape_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice"
<< ", shape should be a 1-D tensor, while input_shape dim num is " << shape_shape.size()
<< ".";
}
if (shape_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice"
<< ", the shape of input shape should be [2] but got shape [" << shape_shape[0] << "].";
}
if (start_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice, start should have 2 elements"
<< ", while its dim0 size is " << start_shape[0] << ".";
}
if (size_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For SparseSlice, size should have 2 elements"
<< ", while its dim0 size is " << size_shape[0] << ".";
}
}
template <typename T>
void IndicesBoundCheck(const T *indices_val, size_t indices_num, const T *shape_val, const std::string &name) {
if (shape_val[0] <= 0 || shape_val[1] <= 0) {
MS_EXCEPTION(ValueError) << "For SparseSlice, " << name << "_shape should be positive, "
<< "while got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
}
size_t num_pairs = indices_num / dim2Num;
for (size_t i = 0; i < num_pairs; i++) {
// indices is an (n, 2) row-major tensor, so pair i lives at [2 * i, 2 * i + 1].
if ((indices_val[dim2Num * i] < 0) || (indices_val[dim2Num * i] >= shape_val[0])) {
MS_EXCEPTION(ValueError) << "For SparseSlice, " << name << "_indices row index should be in [0, " << shape_val[0]
<< "), but got row index " << indices_val[dim2Num * i] << ".";
}
if ((indices_val[dim2Num * i + 1] < 0) || (indices_val[dim2Num * i + 1] >= shape_val[1])) {
MS_EXCEPTION(ValueError) << "For SparseSlice, " << name << "_indices col index should be in [0, " << shape_val[1]
<< "), but got col index " << indices_val[dim2Num * i + 1] << ".";
}
}
}
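// Validates that every index of x1 lies inside x1_shape; only invoked when the
// indices value is already known at infer time (see SparseSliceInferShape).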
void CheckIndices(const std::vector<AbstractBasePtr> &input_args) {
auto x1_indices_abstract = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x1_indices_abstract);
auto x1_indices_value_ptr = x1_indices_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x1_indices_value_ptr);
auto x1_indices_tensor = x1_indices_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x1_indices_tensor);
auto x1_indices_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x1_indices_type);
auto x1_indices_type_id = x1_indices_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(x1_indices_type_id);
auto x1_indices_type_element = x1_indices_type_id->element();
MS_EXCEPTION_IF_NULL(x1_indices_type_element);
auto x1_shape_abstract = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x1_shape_abstract);
auto x1_shape_value_ptr = x1_shape_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x1_shape_value_ptr);
auto x1_shape_tensor = x1_shape_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x1_shape_tensor);
if (x1_indices_type_element->type_id() == kNumberTypeInt32) {
IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
reinterpret_cast<int32_t *>(x1_shape_tensor->data_c()), "x1");
} else {
IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
reinterpret_cast<int64_t *>(x1_shape_tensor->data_c()), "x1");
}
}
abstract::TupleShapePtr SparseSliceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
!input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex0]->BuildValue()->isa<None>()) {
CheckIndices(input_args);
auto input_indices_abstract = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_indices_abstract);
auto input_indices_value_ptr = input_indices_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(input_indices_value_ptr);
auto input_indices_tensor = input_indices_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(input_indices_tensor);
auto input_indices_val = reinterpret_cast<int64_t *>(input_indices_tensor->data_c());
auto input_start_ptr = input_args[kInputIndex3]->BuildValue();
MS_EXCEPTION_IF_NULL(input_start_ptr);
auto input_start_tensor = input_start_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(input_start_tensor);
auto input_start_val = reinterpret_cast<int64_t *>(input_start_tensor->data_c());
auto input_size_ptr = input_args[kInputIndex4]->BuildValue();
MS_EXCEPTION_IF_NULL(input_size_ptr);
auto input_size_tensor = input_size_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(input_size_tensor);
auto input_size_val = reinterpret_cast<int64_t *>(input_size_tensor->data_c());
int64_t count = 0;
int64_t row_end = input_start_val[0] + input_size_val[0];
int64_t col_end = input_start_val[1] + input_size_val[1];
// indices is an (n, 2) row-major tensor (read as int64 here); count the entries
// whose (row, col) falls inside the half-open window [start, start + size).
for (size_t i = 0; i < input_indices_tensor->DataSize(); i += dim2Num) {
if ((input_indices_val[i] >= input_start_val[0] && input_indices_val[i] < row_end) &&
(input_indices_val[i + 1] >= input_start_val[1] && input_indices_val[i + 1] < col_end)) {
count++;
}
}
std::vector<int64_t> output_indices_shape = {count, 2};
abstract::ShapePtr output_indices_shape_list =
std::make_shared<abstract::Shape>(output_indices_shape, output_indices_shape, output_indices_shape);
std::vector<int64_t> output_values_shape = {count};
abstract::ShapePtr output_values_shape_list =
std::make_shared<abstract::Shape>(output_values_shape, output_values_shape, output_values_shape);
std::vector<int64_t> output_size_shape = {2};
abstract::ShapePtr output_size_shape_list =
std::make_shared<abstract::Shape>(output_size_shape, output_size_shape, output_size_shape);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
output_indices_shape_list, output_values_shape_list, output_size_shape_list});
} else {
std::vector<int64_t> output_indices_shape = {abstract::Shape::SHP_ANY, 2};
abstract::ShapePtr output_indices_shape_list =
std::make_shared<abstract::Shape>(output_indices_shape, output_indices_shape, output_indices_shape);
std::vector<int64_t> output_values_shape = {abstract::Shape::SHP_ANY};
abstract::ShapePtr output_values_shape_list =
std::make_shared<abstract::Shape>(output_values_shape, output_values_shape, output_values_shape);
std::vector<int64_t> output_size_shape = {abstract::Shape::SHP_ANY, abstract::Shape::SHP_ANY};
abstract::ShapePtr output_size_shape_list =
std::make_shared<abstract::Shape>(output_size_shape, output_size_shape, output_size_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{output_indices_shape_list, output_values_shape_list, output_size_shape_list});
}
}
TuplePtr SparseSliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = prim->name();
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", input_args[kInputIndex0]->BuildType());
(void)types.emplace("shape", input_args[kInputIndex2]->BuildType());
(void)types.emplace("start", input_args[kInputIndex3]->BuildType());
(void)types.emplace("size", input_args[kInputIndex4]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt64, kInt32}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("values", input_args[kInputIndex1]->BuildType(),
{kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
op_name);
std::map<std::string, TypePtr> args = {{"values", input_args[kInputIndex1]->BuildType()}};
auto output_values_type = CheckAndConvertUtils::CheckTensorTypeSame(
args, {kInt8, kInt16, kInt32, kInt64, kUInt8, kFloat32, kFloat64}, op_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{kInt64, output_values_type, kInt64});
}
}  // namespace
AbstractBasePtr SparseSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
CheckInputTensor(input_args);
auto infer_type = SparseSliceInferType(primitive, input_args);
auto infer_shape = SparseSliceInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSlice, prim::kPrimSparseSlice, SparseSliceInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
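To make the shape inference above concrete, here is a small NumPy sketch (not part of this commit) of the count it computes when all inputs are constant: an entry survives the slice exactly when both of its coordinates fall inside the half-open window [start, start + size).

import numpy as np

def count_kept(indices, start, size):
    # Mirrors the inference loop: keep entry i when
    # start <= indices[i] < start + size holds component-wise.
    kept = ((indices >= start) & (indices < start + size)).all(axis=1)
    return int(kept.sum())

# Same data as the SparseSlice docstring example below: 3 of 4 entries survive.
indices = np.array([[0, 1], [1, 2], [1, 3], [2, 2]])
print(count_kept(indices, np.array([0, 1]), np.array([2, 3])))  # 3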

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SLICE_H_
#define MINDSPORE_CORE_OPS_SPARSE_SLICE_H_
#include <vector>
#include <set>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSlice = "SparseSlice";
/// \brief Slices a SparseTensor based on the "start" and "size".
/// Refer to Python API @ref mindspore.ops.SparseSlice for more details.
class MIND_API SparseSlice : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseSlice);
/// \brief Constructor.
SparseSlice() : BaseOperator(kNameSparseSlice) {
InitIOName(
{"indices", "values", "shape", "start", "size"},
{"y_indices", "y_values", "y_shape"});
}
void Init() {}
};
AbstractBasePtr SparseSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SLICE_H_

View File

@@ -72,7 +72,8 @@ class SparseDenseCwiseAdd(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SparseDenseCwiseAdd."""
self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
self.init_prim_io_names(
inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
class SparseDenseCwiseMul(Primitive):
@@ -123,7 +124,8 @@ class SparseDenseCwiseMul(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SparseDenseCwiseMul."""
self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
self.init_prim_io_names(
inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
class SparseDenseCwiseDiv(Primitive):
@@ -174,7 +176,62 @@ class SparseDenseCwiseDiv(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SparseDenseCwiseDiv."""
self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
self.init_prim_io_names(
inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
class SparseSlice(Primitive):
r"""
Slices a SparseTensor based on the "start" and "size".
Inputs:
- **indices** (Tensor) - A 2D Tensor of type int64. The indices of the SparseTensor.
Each element value should be a non-negative integer. The shape is :math:`(n, 2)`.
- **values** (Tensor) - A 1D Tensor, represents the value corresponding to the position in the `indices`.
The shape should be :math:`(n,)`.
- **shape** (Tensor) - A 1D Tensor of type int64 which specifies the shape of the SparseTensor;
it has 2 elements, representing a sparse tensor of shape :math:`(N, C)`.
- **start** (Tensor) - A 1D Tensor of type int64, represents the start of the slice.
- **size** (Tensor) - A 1D Tensor of type int64, represents the size of the slice.
Outputs:
A `SparseTensor` resulting from the slice.
- **y_indices** (Tensor) - A 2D Tensor of type int64.
- **y_values** (Tensor) - A 1D Tensor. Has the same type as `values`.
- **y_shape** (Tensor) - A 1D Tensor of type int64.
Raises:
TypeError: If the dtype of `indices`, `shape`, `start` or `size` is not int64.
ValueError: If `indices` is not a 2-D tensor.
ValueError: If `values`, `shape`, `start` or `size` is not a 1-D tensor.
ValueError: If the number of `indices` does not correspond to the number of `values`.
ValueError: If any index in `indices` is out of the bounds of `shape`.
Supported Platforms:
Examples:
>>> indices = Tensor([[0, 1], [1, 2], [1, 3], [2, 2]], dtype=ms.int64)
>>> values = Tensor([1, 2, 3, 4])
>>> shape = Tensor([3, 4], dtype=ms.int64)
>>> start = Tensor([0, 1], dtype=ms.int64)
>>> size = Tensor([2, 3], dtype=ms.int64)
>>> sparseslice = ops.SparseSlice()
>>> output = sparseslice(indices, values, shape, start, size)
>>> print(output[0])
[[0 0]
[1 1]
[1 2]]
>>> print(output[1])
[1 2 3]
>>> print(output[2])
[2 3]
"""
@prim_attr_register
def __init__(self):
"""Initialize SparseSlice."""
self.init_prim_io_names(inputs=['indices', 'values', 'shape', 'start', 'size'],
outputs=['y_indices', 'y_values', 'y_shape'])
class SparseToDense(PrimitiveWithInfer):
@@ -214,11 +271,14 @@ class SparseToDense(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize SparseToDense."""
self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
self.init_prim_io_names(
inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
def __infer__(self, indices, values, sparse_shape):
validator.check_tensor_dtype_valid('indices', indices['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid('values', values['dtype'], mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('indices', indices['dtype'], [
mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid(
'values', values['dtype'], mstype.number_type + (mstype.bool_,), self.name)
indices_shape = indices['shape']
if len(indices_shape) != 2:
raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor, "
@@ -294,7 +354,8 @@ class SparseToDenseV2(Primitive):
self.add_prim_attr("max_length", 1000000)
self.validate_indices = validate_indices
self.add_prim_attr("validate_indices", self.validate_indices)
self.init_prim_io_names(inputs=['indices', 'output_shape', 'values', 'default_value'], outputs=['output'])
self.init_prim_io_names(
inputs=['indices', 'output_shape', 'values', 'default_value'], outputs=['output'])
class SparseTensorDenseAdd(Primitive):
@@ -341,7 +402,8 @@ class SparseTensorDenseAdd(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SparseTensorDenseAdd."""
self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
self.init_prim_io_names(
inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2'], outputs=['y'])
class SparseTensorDenseMatmul(Primitive):
@@ -408,10 +470,13 @@ class SparseTensorDenseMatmul(Primitive):
self.set_const_input_indexes([2])
def __infer__(self, indices, values, sparse_shape, dense):
validator.check_tensor_dtype_valid('indices', indices['dtype'], [mstype.int32, mstype.int64], self.name)
valid_types = (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64)
validator.check_tensor_dtype_valid('indices', indices['dtype'], [
mstype.int32, mstype.int64], self.name)
valid_types = (mstype.float16, mstype.float32,
mstype.float64, mstype.int32, mstype.int64)
args = {'values': values['dtype'], 'dense': dense['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(
args, valid_types, self.name)
indices_shape = indices['shape']
if len(indices_shape) != 2 or indices_shape[1] != 2:
raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor and "
@@ -421,7 +486,8 @@ class SparseTensorDenseMatmul(Primitive):
raise ValueError(f"For '{self.name}', the 'values' must be a 1-D tensor and "
f"the first dimension length must be equal to the first dimension length of 'indices', "
f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}.")
a_shape = sparse_shape['value'][::-1] if self.adjoint_st else sparse_shape['value']
a_shape = sparse_shape['value'][::-1] if self.adjoint_st \
else sparse_shape['value']
b_shape = dense['shape'][::-1] if self.adjoint_dt else dense['shape']
for i in a_shape:
if isinstance(i, bool) or not isinstance(i, int) or i <= 0:
@@ -634,9 +700,12 @@ class DenseToDenseSetOperation(Primitive):
@prim_attr_register
def __init__(self, set_operation="a-b", validate_indices=True):
"""Initialize DenseToDenseSetOperation."""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y_indices', 'y_values', 'y_shape'])
validator.check_value_type("set_operation", set_operation, [str], self.name)
validator.check_value_type("validate_indices", validate_indices, [bool], self.name)
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=[
'y_indices', 'y_values', 'y_shape'])
validator.check_value_type(
"set_operation", set_operation, [str], self.name)
validator.check_value_type(
"validate_indices", validate_indices, [bool], self.name)
class Sspaddmm(Primitive):
@@ -1078,7 +1147,8 @@ class SparseAdd(Primitive):
@prim_attr_register
def __init__(self):
self.init_prim_io_names(
inputs=["x1_indices", "x1_values", "x1_shape", "x2_indices", "x2_values", "x2_shape", "thresh"],
inputs=["x1_indices", "x1_values", "x1_shape",
"x2_indices", "x2_values", "x2_shape", "thresh"],
outputs=["sum_indices", "sum_values", "sum_shape"])
@@ -1132,14 +1202,13 @@ class SparseMatrixSoftmax(Primitive):
2.36882806e-01, 6.43914223e-01, 2.68941432e-01, 7.31058598e-01]))
"""
@prim_attr_register
def __init__(self, dtype):
"""Initialize SparseMatrixSoftmax."""
if not isinstance(dtype, (type(mstype.float32), type(mstype.single), type(mstype.float64),
type(mstype.double))):
raise TypeError("Only float32 and float64 type data are supported, but got {}".format(dtype))
raise TypeError(
"Only float32 and float64 type data are supported, but got {}".format(dtype))
self.add_prim_attr("dtype", dtype)
self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
'x_col_indices', 'x_values'],
@@ -1201,7 +1270,8 @@ class CSRSparseMatrixToDense(Primitive):
def __init__(self):
"""Initialize CSRSparseMatrixToDense"""
self.init_prim_io_names(
inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers', 'x_col_indices', 'x_values'],
inputs=['x_dense_shape', 'x_batch_pointers',
'x_row_pointers', 'x_col_indices', 'x_values'],
outputs=['y'])
@@ -1539,8 +1609,10 @@ class SparseMatrixSparseMatMul(Primitive):
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False):
"""Initialize SparseMatrixSparseMatMul"""
validator.check_value_type("transpose_a", transpose_a, [bool], self.name)
validator.check_value_type("transpose_b", transpose_b, [bool], self.name)
validator.check_value_type(
"transpose_a", transpose_a, [bool], self.name)
validator.check_value_type(
"transpose_b", transpose_b, [bool], self.name)
validator.check_value_type("adjoint_a", adjoint_b, [bool], self.name)
validator.check_value_type("adjoint_b", adjoint_b, [bool], self.name)
self.init_prim_io_names(
@@ -1622,12 +1694,16 @@ class SparseMatrixMatMul(Primitive):
def __init__(self, transpose_x1=False, transpose_x2=False, adjoint_x1=False, adjoint_x2=False,
transpose_output=False, conjugate_output=False):
"""Initialize SparseMatrixMatMul"""
validator.check_value_type("transpose_x1", transpose_x1, [bool], self.name)
validator.check_value_type("transpose_x2", transpose_x2, [bool], self.name)
validator.check_value_type(
"transpose_x1", transpose_x1, [bool], self.name)
validator.check_value_type(
"transpose_x2", transpose_x2, [bool], self.name)
validator.check_value_type("adjoint_x1", adjoint_x1, [bool], self.name)
validator.check_value_type("adjoint_x2", adjoint_x2, [bool], self.name)
validator.check_value_type("transpose_output", transpose_output, [bool], self.name)
validator.check_value_type("conjugate_output", conjugate_output, [bool], self.name)
validator.check_value_type(
"transpose_output", transpose_output, [bool], self.name)
validator.check_value_type(
"conjugate_output", conjugate_output, [bool], self.name)
self.init_prim_io_names(inputs=['x1_dense_shape', 'x1_batch_pointers', 'x1_row_pointers',
'x1_col_indices', 'x1_values', 'x2_dense'], outputs=['y_dense'])
@@ -1883,4 +1959,5 @@ class SparseReshape(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SparseReshape."""
self.init_prim_io_names(inputs=['indices', 'shape', 'new_shape'], outputs=['y_indices', 'y_shape'])
self.init_prim_io_names(inputs=['indices', 'shape', 'new_shape'], outputs=[
'y_indices', 'y_shape'])