From 7bf7ae13f9c8b4c127e492105872afcbb51165c2 Mon Sep 17 00:00:00 2001 From: xfb-666 <2293162700@qq.com> Date: Sat, 5 Nov 2022 11:04:41 +0800 Subject: [PATCH] Add RaggedTensorToTensor --- mindspore/core/ops/ragged_tensor_to_tensor.cc | 54 +++++++++++++++++-- .../mindspore/ops/operations/sparse_ops.py | 45 ++++++++++++---- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/mindspore/core/ops/ragged_tensor_to_tensor.cc b/mindspore/core/ops/ragged_tensor_to_tensor.cc index bfb8c80d799..7a45fbaf61b 100644 --- a/mindspore/core/ops/ragged_tensor_to_tensor.cc +++ b/mindspore/core/ops/ragged_tensor_to_tensor.cc @@ -37,13 +37,59 @@ BaseShapePtr RaggedTensorToTensorInferShape(const PrimitivePtr &primitive, CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; CheckAndConvertUtils::CheckInteger("dimension of 'default_value'", SizeToLong(default_value_shape.size()), kLessThan, SizeToLong(values_shape.size()), prim_name); - auto shape_arg = input_args[kInputIndex0]; MS_EXCEPTION_IF_NULL(shape_arg); auto output_shape = GetShapeValue(primitive, shape_arg); - auto row_partition_tensors_shape = - CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; - primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(row_partition_tensors_shape.size()))); + auto values_rank = values_shape.size(); + auto output_shape_rank = output_shape.size(); + auto tensors = input_args[kInputIndex3]->isa() + ? 
input_args[kInputIndex3]->cast()->elements() + : input_args[kInputIndex3]->cast()->elements(); + auto tensors_size = tensors.size(); + const auto &row_partition_types_ptr = primitive->GetAttr("row_partition_types"); + MS_EXCEPTION_IF_NULL(row_partition_types_ptr); + const auto &row_partition_types = GetValue>(row_partition_types_ptr); + auto types_size = row_partition_types.size(); + if (tensors_size != types_size) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the number of row_partition_tensors must be equal to the " + << "number of row_partition_types: " << types_size << ", but got " << tensors_size << "."; + } + if (row_partition_types[0] == "FIRST_DIM_SIZE") { + auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape]; + auto tensor0_dim = tensor0_shape.size(); + CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'FIRST_DIM_SIZE')", + SizeToLong(tensor0_dim), kEqual, 0, prim_name); + if (types_size - 1 + values_rank != output_shape_rank) { + MS_EXCEPTION(ValueError) << "For '" << prim_name + << "', row partition size plus 'values' rank should be equal to 'shape' rank: " + << output_shape.size() << ", but got row partition size: " << (types_size - 1) + << ", 'values' rank: " << values_rank << "."; + } + } else if (row_partition_types[0] == "ROW_SPLITS") { + auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape]; + auto tensor0_dim = tensor0_shape.size(); + CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'ROW_SPLITS')", + SizeToLong(tensor0_dim), kEqual, 1, prim_name); + if (types_size + values_rank != output_shape_rank) { + MS_EXCEPTION(ValueError) << "For '" << prim_name + << "', row partition size plus 'values' rank should be equal to 'shape' rank: " + << output_shape.size() << ", but got row partition size: " << types_size + << ", 'values' rank: " << values_rank << "."; + } + } else if 
(row_partition_types[0] == "VALUE_ROWIDS") { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', cannot handle 'VALUE_ROWIDS' in row_partition_types[0]."; + } else { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', row_partition_types only support 'FIRST_DIM_SIZE', " + << "'VALUE_ROWIDS' and 'ROW_SPLITS', but got unknown string: " << row_partition_types[0] + << "."; + } + for (size_t i = 1; i < types_size; i++) { + auto tensori_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[i]->BuildShape())[kShape]; + auto tensori_dim = tensori_shape.size(); + CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[" + std::to_string(i) + "]", + SizeToLong(tensori_dim), kEqual, 1, prim_name); + } + primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(tensors_size))); return std::make_shared(output_shape); } diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py index 2f341b6fcc9..4d6f416d3f4 100644 --- a/mindspore/python/mindspore/ops/operations/sparse_ops.py +++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py @@ -2508,11 +2508,12 @@ class RaggedTensorToTensor(Primitive): Inputs: - **shape** (Tensor) - A 1-D `Tensor`. Must be one of the following types: `int64`, `int32`. The desired shape of the output tensor. - - **values** (Tensor) - A 1-D `Tensor` representing the values of the ragged tensor. + - **values** (Tensor) - A 1-D or higher `Tensor` representing the values of the ragged tensor. - **default_value** (Tensor) - A `Tensor` representing the default value of the ragged tensor. Must have the same type as `values` and less dimension than `values`. - **row_partition_tensors** (list(Tensor)) - A list of at least 1 `Tensor` objects with the same - type in: `int64`, `int32`. + type in: `int64`, `int32`. The row partition tensor is 0-D, 1-D, 1-D, when the row partition type is + "FIRST_DIM_SIZE", "VALUE_ROWIDS", "ROW_SPLITS" respectively. 
Outputs: A `Tensor`. Has the same type as `values` and the shape is `shape`. @@ -2521,20 +2522,17 @@ class RaggedTensorToTensor(Primitive): TypeError: If the type of `shape`, `values` or `default_value` is not Tensor. ValueError: If the dimension of `shape` or `values` is not 1. ValueError: If the dimension of `default_value` is more than `values`. - RuntimeError: If the order of `row_partition_tensors` is not support + ValueError: If the order or value of `row_partition_types` is not supported. + RuntimeError: If the value of `row_partition_tensors` is not in ascending order when the `row_partition_types` is "ROW_SPLITS". RuntimeError: If value rowid is not less than first dim size when the `row_partition_types` is "FIRST_DIM_SIZE", "VALUE_ROWIDS". - RuntimeError: If the order of `row_partition_types` is not support. - RuntimeError: If the value of `row_partition_types` is not support. - RuntimeError: If row partition size plus `values` rank is not equal to `shape` rank. + ValueError: If row partition size plus `values` rank is not equal to `shape` rank. 
Supported Platforms: ``CPU`` Examples: - >>> from mindspore.common import dtype as mstype - >>> from mindspore.common.tensor import Tensor >>> from mindspore.ops.operations.sparse_ops import RaggedTensorToTensor >>> shape = Tensor([4, 4], mstype.int32) >>> values = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], mstype.int64) @@ -2555,9 +2553,38 @@ class RaggedTensorToTensor(Primitive): @prim_attr_register def __init__(self, row_partition_types): """Initialize RaggedTensorToTensor""" - validator.check_value_type("row_partition_types", row_partition_types, [list], self.name) self.init_prim_io_names(inputs=['shape', 'values', 'default_value', 'row_partition_tensors'], outputs=['result']) + validator.check_value_type("row_partition_types", row_partition_types, [list], self.name) + + if not row_partition_types: + raise ValueError(f"For {self.name}, row_partition_types cannot be empty.") + + for i, item in enumerate(row_partition_types): + validator.check_value_type(f"row_partition_types[{i}]", item, [str], self.name) + + valid_values = ("ROW_SPLITS", "FIRST_DIM_SIZE", "VALUE_ROWIDS") + if not set(row_partition_types).issubset(valid_values): + diff = tuple(set(row_partition_types).difference(valid_values)) + raise ValueError( + f"For {self.name}, row_partition_types only support {valid_values}, " + f"but got {diff if len(diff) > 1 else repr(diff[0])}.") + + first_element = valid_values[:2] + if row_partition_types[0] not in first_element: + raise ValueError( + f"For {self.name}, the first element of row_partition_types must be in {first_element}, " + f"but got '{row_partition_types[0]}'.") + + if row_partition_types[0] == "FIRST_DIM_SIZE": + if set(row_partition_types[1:]) != {"VALUE_ROWIDS"}: + raise ValueError( + f"For {self.name}, 'VALUE_ROWIDS' must be preceded by 'FIRST_DIM_SIZE' in row_partition_types.") + else: + if set(row_partition_types) != {"ROW_SPLITS"}: + raise ValueError( + f"For {self.name}, the each element of row_partition_types must be 'ROW_SPLITS' " + f"when 
row_splits tensor.") class SparseCross(Primitive):