!49000 fix ROIAlignGrad when using dynamic shape

Merge pull request !49000 from JoeyLin/master
This commit is contained in:
i-robot 2023-02-22 01:56:29 +00:00 committed by Gitee
commit 610b9c826d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 26 additions and 7 deletions

View File

@ -36,8 +36,7 @@ class ROIAlignGradInfer : public abstract::OpInferBase {
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
constexpr size_t kInputNum = 3;
(void)CheckAndConvertUtils::CheckInteger("the number of inputs", input_args.size(), kEqual, kInputNum, op_name);
std::vector<int64_t> output_shape;
auto feature_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rois_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
@ -51,11 +50,31 @@ class ROIAlignGradInfer : public abstract::OpInferBase {
(void)CheckAndConvertUtils::CheckInteger("rank of rois shape", SizeToLong(rois_shape.size()), kEqual,
kROIGradRoisShapeSize, op_name);
}
auto input_shape = input_args[kInputIndex2];
ShapeVector out_shape = GetShapeValue(primitive, input_shape);
return std::make_shared<abstract::Shape>(out_shape);
constexpr size_t kInputNum = 3;
if (input_args.size() == kInputNum) {
auto input_shape = input_args[kInputIndex2];
output_shape = GetShapeValue(primitive, input_shape);
return std::make_shared<abstract::Shape>(output_shape);
} else if (input_args.size() == kInputNum - 1) {
auto input_shape = primitive->GetAttr("xdiff_shape");
MS_EXCEPTION_IF_NULL(input_shape);
auto input_shape_tuple = input_shape->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(input_shape_tuple);
auto input_tuple = input_shape_tuple->value();
(void)std::transform(input_tuple.begin(), input_tuple.end(), std::back_inserter(output_shape),
[&op_name, &input_shape](const ValuePtr &size_value) -> int64_t {
if (!size_value->isa<Int64Imm>()) {
MS_EXCEPTION(TypeError)
<< "For primitive[" << op_name << "], the 'shape'"
<< " must be a tuple with all Int elements, but got " << input_shape->ToString();
}
return GetValue<int64_t>(size_value);
});
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << op_name << "], the 'input num'"
<< " must be 2 or 3, but got " << input_args.size();
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {