forked from mindspore-Ecosystem/mindspore
fix ctcloss tupel input check
This commit is contained in:
parent
11787a7d12
commit
e22b3e1861
|
@ -29,6 +29,21 @@ int64_t CTCLossV2::get_blank() const { return GetValue<int64_t>(GetAttr(kAttrBla
|
|||
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
|
||||
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
|
||||
namespace {
|
||||
void CheckInputLengthType(const std::string &arg_name, const AbstractBasePtr &input_arg,
|
||||
const std::set<TypePtr> &valid_type, const std::string &prim_name) {
|
||||
if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||
(void)CheckAndConvertUtils::CheckTypeValid(arg_name, input_arg->BuildType(), valid_type, prim_name);
|
||||
} else if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||
auto elements = input_arg->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
(void)CheckAndConvertUtils::CheckSubClass(arg_name, elements[i]->BuildType(), valid_type, prim_name);
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input " << input_arg->type_name()
|
||||
<< " must be a tuple or a tensor with all Int elements, but got " << input_arg->ToString()
|
||||
<< ".";
|
||||
}
|
||||
}
|
||||
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
|
@ -42,6 +57,14 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
auto input_lengths_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex2]->BuildShape())[kShape];
|
||||
auto target_lengths_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
|
||||
if (input_args[kIndex2]->isa<abstract::AbstractTuple>()) {
|
||||
auto size = input_args[kIndex2]->cast<abstract::AbstractTuplePtr>()->elements().size();
|
||||
input_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
|
||||
}
|
||||
if (input_args[kIndex3]->isa<abstract::AbstractTuple>()) {
|
||||
auto size = input_args[kIndex3]->cast<abstract::AbstractTuplePtr>()->elements().size();
|
||||
target_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
|
||||
}
|
||||
|
||||
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
|
||||
IsDynamicRank(target_lengths_shape)) {
|
||||
|
@ -88,10 +111,9 @@ TuplePtr CTCLossV2InferType(const PrimitivePtr &primitive, const std::vector<Abs
|
|||
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[kInputIndex0]->BuildType(),
|
||||
{kFloat32, kFloat64}, name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("targets", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("input_lengths", input_args[kInputIndex2]->BuildType(), {kInt32, kInt64},
|
||||
name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("target_lengths", input_args[kInputIndex3]->BuildType(), {kInt32, kInt64},
|
||||
name);
|
||||
|
||||
CheckInputLengthType("input_lengths", input_args[kInputIndex2], {kInt32, kInt64}, name);
|
||||
CheckInputLengthType("target_lengths", input_args[kInputIndex3], {kInt32, kInt64}, name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -481,10 +481,11 @@ class Im2Col(Primitive):
|
|||
self.ksizes = (ksizes, ksizes) if isinstance(ksizes, int) else ksizes
|
||||
self.strides = (strides, strides) if isinstance(strides, int) else strides
|
||||
self.dilations = (dilations, dilations) if isinstance(dilations, int) else dilations
|
||||
self.pads = pads
|
||||
if isinstance(pads, (list, tuple)):
|
||||
if len(pads) == 2:
|
||||
self.pads = (pads[0], pads[0], pads[1], pads[1])
|
||||
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else pads
|
||||
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else self.pads
|
||||
|
||||
validator.check("ksizes size", len(self.ksizes), "", [1, 2], Rel.IN, self.name)
|
||||
validator.check_positive_int_sequence(self.ksizes, "ksizes", self.name)
|
||||
|
|
Loading…
Reference in New Issue