forked from mindspore-Ecosystem/mindspore
!49000 fix ROIAlignGrad when using dynamic shape
Merge pull request !49000 from JoeyLin/master
This commit is contained in:
commit
610b9c826d
|
@ -36,8 +36,7 @@ class ROIAlignGradInfer : public abstract::OpInferBase {
|
|||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
constexpr size_t kInputNum = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("the number of inputs", input_args.size(), kEqual, kInputNum, op_name);
|
||||
std::vector<int64_t> output_shape;
|
||||
auto feature_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto rois_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
|
@ -51,11 +50,31 @@ class ROIAlignGradInfer : public abstract::OpInferBase {
|
|||
(void)CheckAndConvertUtils::CheckInteger("rank of rois shape", SizeToLong(rois_shape.size()), kEqual,
|
||||
kROIGradRoisShapeSize, op_name);
|
||||
}
|
||||
|
||||
auto input_shape = input_args[kInputIndex2];
|
||||
ShapeVector out_shape = GetShapeValue(primitive, input_shape);
|
||||
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
constexpr size_t kInputNum = 3;
|
||||
if (input_args.size() == kInputNum) {
|
||||
auto input_shape = input_args[kInputIndex2];
|
||||
output_shape = GetShapeValue(primitive, input_shape);
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
} else if (input_args.size() == kInputNum - 1) {
|
||||
auto input_shape = primitive->GetAttr("xdiff_shape");
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
auto input_shape_tuple = input_shape->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_tuple);
|
||||
auto input_tuple = input_shape_tuple->value();
|
||||
(void)std::transform(input_tuple.begin(), input_tuple.end(), std::back_inserter(output_shape),
|
||||
[&op_name, &input_shape](const ValuePtr &size_value) -> int64_t {
|
||||
if (!size_value->isa<Int64Imm>()) {
|
||||
MS_EXCEPTION(TypeError)
|
||||
<< "For primitive[" << op_name << "], the 'shape'"
|
||||
<< " must be a tuple with all Int elements, but got " << input_shape->ToString();
|
||||
}
|
||||
return GetValue<int64_t>(size_value);
|
||||
});
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << op_name << "], the 'input num'"
|
||||
<< " must be 2 or 3, but got " << input_args.size();
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
|
|
Loading…
Reference in New Issue