forked from mindspore-Ecosystem/mindspore
!46297 [assistant][ops] Fix RaggedTensorToTensor
Merge pull request !46297 from 夏芳冰/RaggedTensorToTensor
This commit is contained in:
commit
d870c9090c
|
@ -37,13 +37,59 @@ BaseShapePtr RaggedTensorToTensorInferShape(const PrimitivePtr &primitive,
|
|||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("dimension of 'default_value'", SizeToLong(default_value_shape.size()), kLessThan,
|
||||
SizeToLong(values_shape.size()), prim_name);
|
||||
|
||||
auto shape_arg = input_args[kInputIndex0];
|
||||
MS_EXCEPTION_IF_NULL(shape_arg);
|
||||
auto output_shape = GetShapeValue(primitive, shape_arg);
|
||||
auto row_partition_tensors_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(row_partition_tensors_shape.size())));
|
||||
auto values_rank = values_shape.size();
|
||||
auto output_shape_rank = output_shape.size();
|
||||
auto tensors = input_args[kInputIndex3]->isa<abstract::AbstractTuple>()
|
||||
? input_args[kInputIndex3]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||
: input_args[kInputIndex3]->cast<abstract::AbstractListPtr>()->elements();
|
||||
auto tensors_size = tensors.size();
|
||||
const auto &row_partition_types_ptr = primitive->GetAttr("row_partition_types");
|
||||
MS_EXCEPTION_IF_NULL(row_partition_types_ptr);
|
||||
const auto &row_partition_types = GetValue<std::vector<std::string>>(row_partition_types_ptr);
|
||||
auto types_size = row_partition_types.size();
|
||||
if (tensors_size != types_size) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the number of row_partition_tensors must be equal to the "
|
||||
<< "number of row_partition_types: " << types_size << ", but got " << tensors_size << ".";
|
||||
}
|
||||
if (row_partition_types[0] == "FIRST_DIM_SIZE") {
|
||||
auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape];
|
||||
auto tensor0_dim = tensor0_shape.size();
|
||||
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'FIRST_DIM_SIZE')",
|
||||
SizeToLong(tensor0_dim), kEqual, 0, prim_name);
|
||||
if (types_size - 1 + values_rank != output_shape_rank) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', row partition size plus 'values' rank should be equal to 'shape' rank: "
|
||||
<< output_shape.size() << ", but got row partition size: " << (types_size - 1)
|
||||
<< ", 'values' rank: " << values_rank << ".";
|
||||
}
|
||||
} else if (row_partition_types[0] == "ROW_SPLITS") {
|
||||
auto tensor0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[0]->BuildShape())[kShape];
|
||||
auto tensor0_dim = tensor0_shape.size();
|
||||
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[0](for 'ROW_SPLITS')",
|
||||
SizeToLong(tensor0_dim), kEqual, 1, prim_name);
|
||||
if (types_size + values_rank != output_shape_rank) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', row partition size plus 'values' rank should be equal to 'shape' rank: "
|
||||
<< output_shape.size() << ", but got row partition size: " << types_size
|
||||
<< ", 'values' rank: " << values_rank << ".";
|
||||
}
|
||||
} else if (row_partition_types[0] == "VALUE_ROWIDS") {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', cannot handle 'VALUE_ROWIDS' in row_partition_types[0].";
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', row_partition_types only support 'FIRST_DIM_SIZE', "
|
||||
<< "'VALUE_ROWIDS' and 'ROW_SPLITS', but got unknown string: " << row_partition_types[0]
|
||||
<< ".";
|
||||
}
|
||||
for (size_t i = 1; i < types_size; i++) {
|
||||
auto tensori_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tensors[i]->BuildShape())[kShape];
|
||||
auto tensori_dim = tensori_shape.size();
|
||||
CheckAndConvertUtils::CheckInteger("dimension of row_partition_tensors[" + std::to_string(i) + "]",
|
||||
SizeToLong(tensori_dim), kEqual, 1, prim_name);
|
||||
}
|
||||
primitive->AddAttr("num_row_partition_tensors", MakeValue(SizeToLong(tensors_size)));
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -2520,11 +2520,12 @@ class RaggedTensorToTensor(Primitive):
|
|||
Inputs:
|
||||
- **shape** (Tensor) - A 1-D `Tensor`. Must be one of the following types: `int64`, `int32`.
|
||||
The desired shape of the output tensor.
|
||||
- **values** (Tensor) - A 1-D `Tensor` representing the values of the ragged tensor.
|
||||
- **values** (Tensor) - A 1-D or higher `Tensor` representing the values of the ragged tensor.
|
||||
- **default_value** (Tensor) - A `Tensor` representing the default value of the ragged tensor.
|
||||
Must have the same type as `values` and less dimension than `values`.
|
||||
- **row_partition_tensors** (list(Tensor)) - A list of at least 1 `Tensor` objects with the same
|
||||
type in: `int64`, `int32`.
|
||||
type in: `int64`, `int32`. The row partition tensor is 0-D, 1-D, 1-D, when the row partition type is
|
||||
"FIRST_DIM_SIZE", "VALUE_ROWIDS", "ROW_SPLITS" respectively.
|
||||
|
||||
Outputs:
|
||||
A `Tensor`. Has the same type as `values` and the shape is `shape`.
|
||||
|
@ -2533,20 +2534,17 @@ class RaggedTensorToTensor(Primitive):
|
|||
TypeError: If the type of `shape`, `values` or `default_value` is not Tensor.
|
||||
ValueError: If the dimension of `shape` or `values` is not 1.
|
||||
ValueError: If the dimension of `default_value` is more than `values`.
|
||||
RuntimeError: If the order of `row_partition_tensors` is not support
|
||||
ValueError: If the order or value of `row_partition_types` is not support.
|
||||
RuntimeError: If the value of `row_partition_tensors` is not in ascending order
|
||||
when the `row_partition_types` is "ROW_SPLITS".
|
||||
RuntimeError: If value rowid is not less than first dim size
|
||||
when the `row_partition_types` is "FIRST_DIM_SIZE", "VALUE_ROWIDS".
|
||||
RuntimeError: If the order of `row_partition_types` is not support.
|
||||
RuntimeError: If the value of `row_partition_types` is not support.
|
||||
RuntimeError: If row partition size plus `values` rank is not equal to `shape` rank.
|
||||
ValueError: If row partition size plus `values` rank is not equal to `shape` rank.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.common import dtype as mstype
|
||||
>>> from mindspore.common.tensor import Tensor
|
||||
>>> from mindspore.ops.operations.sparse_ops import RaggedTensorToTensor
|
||||
>>> shape = Tensor([4, 4], mstype.int32)
|
||||
>>> values = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], mstype.int64)
|
||||
|
@ -2567,9 +2565,38 @@ class RaggedTensorToTensor(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, row_partition_types):
|
||||
"""Initialize RaggedTensorToTensor"""
|
||||
validator.check_value_type("row_partition_types", row_partition_types, [list], self.name)
|
||||
self.init_prim_io_names(inputs=['shape', 'values', 'default_value', 'row_partition_tensors'],
|
||||
outputs=['result'])
|
||||
validator.check_value_type("row_partition_types", row_partition_types, [list], self.name)
|
||||
|
||||
if not row_partition_types:
|
||||
raise ValueError(f"For {self.name}, row_partition_types cannot be empty.")
|
||||
|
||||
for i, item in enumerate(row_partition_types):
|
||||
validator.check_value_type(f"row_partition_types[{i}]", item, [str], self.name)
|
||||
|
||||
valid_values = ("ROW_SPLITS", "FIRST_DIM_SIZE", "VALUE_ROWIDS")
|
||||
if not set(row_partition_types).issubset(valid_values):
|
||||
diff = tuple(set(row_partition_types).difference(valid_values))
|
||||
raise ValueError(
|
||||
f"For {self.name}, row_partition_types only support {valid_values}, "
|
||||
f"but got {diff if len(diff) > 1 else repr(diff[0])}.")
|
||||
|
||||
first_element = valid_values[:2]
|
||||
if row_partition_types[0] not in first_element:
|
||||
raise ValueError(
|
||||
f"For {self.name}, the first element of row_partition_types must be in {first_element}, "
|
||||
f"but got '{row_partition_types[0]}'.")
|
||||
|
||||
if row_partition_types[0] == "FIRST_DIM_SIZE":
|
||||
if set(row_partition_types[1:]) != {"VALUE_ROWIDS"}:
|
||||
raise ValueError(
|
||||
f"For {self.name}, 'VALUE_ROWIDS' must be preceded by 'FIRST_DIM_SIZE' in row_partition_types.")
|
||||
else:
|
||||
if set(row_partition_types) != {"ROW_SPLITS"}:
|
||||
raise ValueError(
|
||||
f"For {self.name}, the each element of row_partition_types must be 'ROW_SPLITS' "
|
||||
f"when row_splits tensor.")
|
||||
|
||||
|
||||
class SparseCross(Primitive):
|
||||
|
|
Loading…
Reference in New Issue