!46297 [assistant][ops] Fix RaggedTensorToTensor

Merge pull request !46297 from 夏芳冰/RaggedTensorToTensor
This commit is contained in:
i-robot 2022-12-04 03:47:35 +00:00 committed by Gitee
commit d870c9090c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 86 additions and 13 deletions

View File

@ -37,13 +37,59 @@ BaseShapePtr RaggedTensorToTensorInferShape(const PrimitivePtr &primitive,
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("dimension of 'default_value'", SizeToLong(default_value_shape.size()), kLessThan,
SizeToLong(values_shape.size()), prim_name);
auto shape_arg = input_args[kInputIndex0];
MS_EXCEPTION_IF_NULL(shape_arg);
auto output_shape = GetShapeValue(primitive, shape_arg);
auto row_partition_tensors_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(row_partition_tensors_shape.size())));
auto values_rank = values_shape.size();
auto output_shape_rank = output_shape.size();
auto tensors = input_args[kInputIndex3]->isa<abstract::AbstractTuple>()
? input_args[kInputIndex3]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[kInputIndex3]->cast<abstract::AbstractListPtr>()->elements();
auto tensors_size = tensors.size();
const auto &row_partition_types_ptr = primitive->GetAttr("row_partition_types");
MS_EXCEPTION_IF_NULL(row_partition_types_ptr);
const auto &row_partition_types = GetValue<std::vector<std::string>>(row_partition_types_ptr);
auto types_size = row_partition_types.size();
if (tensors_size != types_size) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the number of row_partition_tensors must be equal to the "
<< "number of row_partition_types: " << types_size << ", but got " << tensors_size << ".";
}
if (row_partition_types[0] == "FIRST_DIM_SIZE") {
auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape];
auto tensor0_dim = tensor0_shape.size();
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'FIRST_DIM_SIZE')",
SizeToLong(tensor0_dim), kEqual, 0, prim_name);
if (types_size - 1 + values_rank != output_shape_rank) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', row partition size plus 'values' rank should be equal to 'shape' rank: "
<< output_shape.size() << ", but got row partition size: " << (types_size - 1)
<< ", 'values' rank: " << values_rank << ".";
}
} else if (row_partition_types[0] == "ROW_SPLITS") {
auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape];
auto tensor0_dim = tensor0_shape.size();
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'ROW_SPLITS')",
SizeToLong(tensor0_dim), kEqual, 1, prim_name);
if (types_size + values_rank != output_shape_rank) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', row partition size plus 'values' rank should be equal to 'shape' rank: "
<< output_shape.size() << ", but got row partition size: " << types_size
<< ", 'values' rank: " << values_rank << ".";
}
} else if (row_partition_types[0] == "VALUE_ROWIDS") {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', cannot handle 'VALUE_ROWIDS' in row_partition_types[0].";
} else {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', row_partition_types only support 'FIRST_DIM_SIZE', "
<< "'VALUE_ROWIDS' and 'ROW_SPLITS', but got unknown string: " << row_partition_types[0]
<< ".";
}
for (size_t i = 1; i < types_size; i++) {
auto tensori_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[i]->BuildShape())[kShape];
auto tensori_dim = tensori_shape.size();
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[" + std::to_string(i) + "]",
SizeToLong(tensori_dim), kEqual, 1, prim_name);
}
primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(tensors_size)));
return std::make_shared<abstract::Shape>(output_shape);
}

View File

@ -2520,11 +2520,12 @@ class RaggedTensorToTensor(Primitive):
Inputs:
- **shape** (Tensor) - A 1-D `Tensor`. Must be one of the following types: `int64`, `int32`.
The desired shape of the output tensor.
- **values** (Tensor) - A 1-D `Tensor` representing the values of the ragged tensor.
- **values** (Tensor) - A 1-D or higher `Tensor` representing the values of the ragged tensor.
- **default_value** (Tensor) - A `Tensor` representing the default value of the ragged tensor.
Must have the same type as `values` and less dimension than `values`.
- **row_partition_tensors** (list(Tensor)) - A list of at least 1 `Tensor` objects with the same
type in: `int64`, `int32`.
type in: `int64`, `int32`. The row partition tensor is 0-D, 1-D, 1-D, when the row partition type is
"FIRST_DIM_SIZE", "VALUE_ROWIDS", "ROW_SPLITS" respectively.
Outputs:
A `Tensor`. Has the same type as `values` and the shape is `shape`.
@ -2533,20 +2534,17 @@ class RaggedTensorToTensor(Primitive):
TypeError: If the type of `shape`, `values` or `default_value` is not Tensor.
ValueError: If the dimension of `shape` or `values` is not 1.
ValueError: If the dimension of `default_value` is more than `values`.
RuntimeError: If the order of `row_partition_tensors` is not supported.
ValueError: If the order or value of `row_partition_types` is not supported.
RuntimeError: If the value of `row_partition_tensors` is not in ascending order
when the `row_partition_types` is "ROW_SPLITS".
RuntimeError: If a value row id is not less than the first dimension size
when the `row_partition_types` is "FIRST_DIM_SIZE", "VALUE_ROWIDS".
RuntimeError: If the order of `row_partition_types` is not supported.
RuntimeError: If the value of `row_partition_types` is not supported.
RuntimeError: If row partition size plus `values` rank is not equal to `shape` rank.
ValueError: If row partition size plus `values` rank is not equal to `shape` rank.
Supported Platforms:
``CPU``
Examples:
>>> from mindspore.common import dtype as mstype
>>> from mindspore.common.tensor import Tensor
>>> from mindspore.ops.operations.sparse_ops import RaggedTensorToTensor
>>> shape = Tensor([4, 4], mstype.int32)
>>> values = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], mstype.int64)
@ -2567,9 +2565,38 @@ class RaggedTensorToTensor(Primitive):
@prim_attr_register
def __init__(self, row_partition_types):
    """Initialize RaggedTensorToTensor.

    Registers the primitive's input/output names and validates
    `row_partition_types`.

    Args:
        row_partition_types (list[str]): Non-empty list whose items are taken from
            "ROW_SPLITS", "FIRST_DIM_SIZE" and "VALUE_ROWIDS". The first item must be
            "ROW_SPLITS" or "FIRST_DIM_SIZE".

    Raises:
        ValueError: If `row_partition_types` is empty.
        ValueError: If `row_partition_types` contains a string outside the supported values.
        ValueError: If the first item is neither "ROW_SPLITS" nor "FIRST_DIM_SIZE".
        ValueError: If the first item is "FIRST_DIM_SIZE" and the remaining items are not
            all "VALUE_ROWIDS".
        ValueError: If the first item is "ROW_SPLITS" and any other item is not "ROW_SPLITS".

    Type errors for the list itself and for its items are reported by
    `validator.check_value_type`.
    """
    # The container type is validated once; a second identical check would be redundant.
    validator.check_value_type("row_partition_types", row_partition_types, [list], self.name)
    self.init_prim_io_names(inputs=['shape', 'values', 'default_value', 'row_partition_tensors'],
                            outputs=['result'])

    if not row_partition_types:
        raise ValueError(f"For {self.name}, row_partition_types cannot be empty.")
    for i, item in enumerate(row_partition_types):
        validator.check_value_type(f"row_partition_types[{i}]", item, [str], self.name)

    # Reject any string that is not one of the supported partition kinds.
    valid_values = ("ROW_SPLITS", "FIRST_DIM_SIZE", "VALUE_ROWIDS")
    unknown = tuple(set(row_partition_types).difference(valid_values))
    if unknown:
        # A single offender is shown as a quoted string, several as a tuple.
        shown = unknown if len(unknown) > 1 else repr(unknown[0])
        raise ValueError(
            f"For {self.name}, row_partition_types only support {valid_values}, "
            f"but got {shown}.")

    # Only "ROW_SPLITS" and "FIRST_DIM_SIZE" may appear in the leading position.
    first_element = valid_values[:2]
    head = row_partition_types[0]
    if head not in first_element:
        raise ValueError(
            f"For {self.name}, the first element of row_partition_types must be in {first_element}, "
            f"but got '{head}'.")

    if head == "FIRST_DIM_SIZE":
        # Everything after the leading "FIRST_DIM_SIZE" must be "VALUE_ROWIDS".
        # A lone "FIRST_DIM_SIZE" is rejected as well, because the remaining set is empty.
        if set(row_partition_types[1:]) != {"VALUE_ROWIDS"}:
            raise ValueError(
                f"For {self.name}, 'VALUE_ROWIDS' must be preceded by 'FIRST_DIM_SIZE' in row_partition_types.")
    elif set(row_partition_types) != {"ROW_SPLITS"}:
        # The leading item is "ROW_SPLITS" here, so every item has to be "ROW_SPLITS".
        raise ValueError(
            f"For {self.name}, the each element of row_partition_types must be 'ROW_SPLITS' "
            f"when row_splits tensor.")
class SparseCross(Primitive):