[fix][assistant][I48OB7] Modify the logic error in infershape

This commit is contained in:
bsx 2022-01-11 20:00:31 +08:00
parent 1cb8896ed8
commit 890fd9b395
1 changed files with 0 additions and 12 deletions

View File

@ -38,10 +38,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
auto ms_shape_ptr = ms_shape->cast<abstract::ShapePtr>();
auto mom_shape_ptr = mom_shape->cast<abstract::ShapePtr>();
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
auto lr_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->GetShapeTrack());
auto decay_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->GetShapeTrack());
auto momentum_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[7]->GetShapeTrack());
auto epsilon_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[8]->GetShapeTrack());
// var and mg must have the same shape when is not dynamic
if (!var_shape_ptr->IsDynamic() && !mg_shape_ptr->IsDynamic()) {
if (*var_shape != *mg_shape) {
@ -70,14 +66,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
<< " are not consistent with var shape " << var_shape->ToString();
}
}
const int64_t kShapeSize = 0;
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape_map[kShape].size(), kEqual, kShapeSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("decay_shape size", decay_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
(void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape size", epsilon_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
auto shape_element = var_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;