fix ctcloss tupel input check

This commit is contained in:
baihuawei 2023-02-08 15:13:40 +08:00
parent 11787a7d12
commit e22b3e1861
2 changed files with 28 additions and 5 deletions

View File

@ -29,6 +29,21 @@ int64_t CTCLossV2::get_blank() const { return GetValue<int64_t>(GetAttr(kAttrBla
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
namespace {
void CheckInputLengthType(const std::string &arg_name, const AbstractBasePtr &input_arg,
const std::set<TypePtr> &valid_type, const std::string &prim_name) {
if (input_arg->isa<abstract::AbstractTensor>()) {
(void)CheckAndConvertUtils::CheckTypeValid(arg_name, input_arg->BuildType(), valid_type, prim_name);
} else if (input_arg->isa<abstract::AbstractTuple>()) {
auto elements = input_arg->cast<abstract::AbstractTuplePtr>()->elements();
for (size_t i = 0; i < elements.size(); ++i) {
(void)CheckAndConvertUtils::CheckSubClass(arg_name, elements[i]->BuildType(), valid_type, prim_name);
}
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input " << input_arg->type_name()
<< " must be a tuple or a tensor with all Int elements, but got " << input_arg->ToString()
<< ".";
}
}
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr size_t kLenLogProbs = 3;
@ -42,6 +57,14 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
auto input_lengths_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex2]->BuildShape())[kShape];
auto target_lengths_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
if (input_args[kIndex2]->isa<abstract::AbstractTuple>()) {
auto size = input_args[kIndex2]->cast<abstract::AbstractTuplePtr>()->elements().size();
input_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
}
if (input_args[kIndex3]->isa<abstract::AbstractTuple>()) {
auto size = input_args[kIndex3]->cast<abstract::AbstractTuplePtr>()->elements().size();
target_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
}
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
IsDynamicRank(target_lengths_shape)) {
@ -88,10 +111,9 @@ TuplePtr CTCLossV2InferType(const PrimitivePtr &primitive, const std::vector<Abs
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[kInputIndex0]->BuildType(),
{kFloat32, kFloat64}, name);
(void)CheckAndConvertUtils::CheckTypeValid("targets", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, name);
(void)CheckAndConvertUtils::CheckTypeValid("input_lengths", input_args[kInputIndex2]->BuildType(), {kInt32, kInt64},
name);
(void)CheckAndConvertUtils::CheckTypeValid("target_lengths", input_args[kInputIndex3]->BuildType(), {kInt32, kInt64},
name);
CheckInputLengthType("input_lengths", input_args[kInputIndex2], {kInt32, kInt64}, name);
CheckInputLengthType("target_lengths", input_args[kInputIndex3], {kInt32, kInt64}, name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
}
} // namespace

View File

@ -481,10 +481,11 @@ class Im2Col(Primitive):
self.ksizes = (ksizes, ksizes) if isinstance(ksizes, int) else ksizes
self.strides = (strides, strides) if isinstance(strides, int) else strides
self.dilations = (dilations, dilations) if isinstance(dilations, int) else dilations
self.pads = pads
if isinstance(pads, (list, tuple)):
if len(pads) == 2:
self.pads = (pads[0], pads[0], pads[1], pads[1])
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else pads
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else self.pads
validator.check("ksizes size", len(self.ksizes), "", [1, 2], Rel.IN, self.name)
validator.check_positive_int_sequence(self.ksizes, "ksizes", self.name)