forked from mindspore-Ecosystem/mindspore
fix ctcloss tupel input check
This commit is contained in:
parent
11787a7d12
commit
e22b3e1861
|
@ -29,6 +29,21 @@ int64_t CTCLossV2::get_blank() const { return GetValue<int64_t>(GetAttr(kAttrBla
|
||||||
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
|
std::string CTCLossV2::get_reduction() const { return GetValue<std::string>(GetAttr(kAttrReduction)); }
|
||||||
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
|
bool CTCLossV2::get_zero_infinity() const { return GetValue<bool>(GetAttr(kAttrZeroInfinity)); }
|
||||||
namespace {
|
namespace {
|
||||||
|
void CheckInputLengthType(const std::string &arg_name, const AbstractBasePtr &input_arg,
|
||||||
|
const std::set<TypePtr> &valid_type, const std::string &prim_name) {
|
||||||
|
if (input_arg->isa<abstract::AbstractTensor>()) {
|
||||||
|
(void)CheckAndConvertUtils::CheckTypeValid(arg_name, input_arg->BuildType(), valid_type, prim_name);
|
||||||
|
} else if (input_arg->isa<abstract::AbstractTuple>()) {
|
||||||
|
auto elements = input_arg->cast<abstract::AbstractTuplePtr>()->elements();
|
||||||
|
for (size_t i = 0; i < elements.size(); ++i) {
|
||||||
|
(void)CheckAndConvertUtils::CheckSubClass(arg_name, elements[i]->BuildType(), valid_type, prim_name);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input " << input_arg->type_name()
|
||||||
|
<< " must be a tuple or a tensor with all Int elements, but got " << input_arg->ToString()
|
||||||
|
<< ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
constexpr size_t kLenLogProbs = 3;
|
constexpr size_t kLenLogProbs = 3;
|
||||||
|
@ -42,6 +57,14 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||||
auto input_lengths_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex2]->BuildShape())[kShape];
|
auto input_lengths_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex2]->BuildShape())[kShape];
|
||||||
auto target_lengths_shape =
|
auto target_lengths_shape =
|
||||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
|
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
|
||||||
|
if (input_args[kIndex2]->isa<abstract::AbstractTuple>()) {
|
||||||
|
auto size = input_args[kIndex2]->cast<abstract::AbstractTuplePtr>()->elements().size();
|
||||||
|
input_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
|
||||||
|
}
|
||||||
|
if (input_args[kIndex3]->isa<abstract::AbstractTuple>()) {
|
||||||
|
auto size = input_args[kIndex3]->cast<abstract::AbstractTuplePtr>()->elements().size();
|
||||||
|
target_lengths_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{SizeToLong(size)})->shape();
|
||||||
|
}
|
||||||
|
|
||||||
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
|
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
|
||||||
IsDynamicRank(target_lengths_shape)) {
|
IsDynamicRank(target_lengths_shape)) {
|
||||||
|
@ -88,10 +111,9 @@ TuplePtr CTCLossV2InferType(const PrimitivePtr &primitive, const std::vector<Abs
|
||||||
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[kInputIndex0]->BuildType(),
|
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[kInputIndex0]->BuildType(),
|
||||||
{kFloat32, kFloat64}, name);
|
{kFloat32, kFloat64}, name);
|
||||||
(void)CheckAndConvertUtils::CheckTypeValid("targets", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, name);
|
(void)CheckAndConvertUtils::CheckTypeValid("targets", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, name);
|
||||||
(void)CheckAndConvertUtils::CheckTypeValid("input_lengths", input_args[kInputIndex2]->BuildType(), {kInt32, kInt64},
|
|
||||||
name);
|
CheckInputLengthType("input_lengths", input_args[kInputIndex2], {kInt32, kInt64}, name);
|
||||||
(void)CheckAndConvertUtils::CheckTypeValid("target_lengths", input_args[kInputIndex3]->BuildType(), {kInt32, kInt64},
|
CheckInputLengthType("target_lengths", input_args[kInputIndex3], {kInt32, kInt64}, name);
|
||||||
name);
|
|
||||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -481,10 +481,11 @@ class Im2Col(Primitive):
|
||||||
self.ksizes = (ksizes, ksizes) if isinstance(ksizes, int) else ksizes
|
self.ksizes = (ksizes, ksizes) if isinstance(ksizes, int) else ksizes
|
||||||
self.strides = (strides, strides) if isinstance(strides, int) else strides
|
self.strides = (strides, strides) if isinstance(strides, int) else strides
|
||||||
self.dilations = (dilations, dilations) if isinstance(dilations, int) else dilations
|
self.dilations = (dilations, dilations) if isinstance(dilations, int) else dilations
|
||||||
|
self.pads = pads
|
||||||
if isinstance(pads, (list, tuple)):
|
if isinstance(pads, (list, tuple)):
|
||||||
if len(pads) == 2:
|
if len(pads) == 2:
|
||||||
self.pads = (pads[0], pads[0], pads[1], pads[1])
|
self.pads = (pads[0], pads[0], pads[1], pads[1])
|
||||||
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else pads
|
self.pads = (pads, pads, pads, pads) if isinstance(pads, int) else self.pads
|
||||||
|
|
||||||
validator.check("ksizes size", len(self.ksizes), "", [1, 2], Rel.IN, self.name)
|
validator.check("ksizes size", len(self.ksizes), "", [1, 2], Rel.IN, self.name)
|
||||||
validator.check_positive_int_sequence(self.ksizes, "ksizes", self.name)
|
validator.check_positive_int_sequence(self.ksizes, "ksizes", self.name)
|
||||||
|
|
Loading…
Reference in New Issue