[fix][assistant][I48OB7] Modify the logic error in infershape
This commit is contained in:
parent
1cb8896ed8
commit
890fd9b395
|
@ -38,10 +38,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
|
|||
auto ms_shape_ptr = ms_shape->cast<abstract::ShapePtr>();
|
||||
auto mom_shape_ptr = mom_shape->cast<abstract::ShapePtr>();
|
||||
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
|
||||
auto lr_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->GetShapeTrack());
|
||||
auto decay_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->GetShapeTrack());
|
||||
auto momentum_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[7]->GetShapeTrack());
|
||||
auto epsilon_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[8]->GetShapeTrack());
|
||||
// var and mg must have the same shape when is not dynamic
|
||||
if (!var_shape_ptr->IsDynamic() && !mg_shape_ptr->IsDynamic()) {
|
||||
if (*var_shape != *mg_shape) {
|
||||
|
@ -70,14 +66,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
|
|||
<< " are not consistent with var shape " << var_shape->ToString();
|
||||
}
|
||||
}
|
||||
const int64_t kShapeSize = 0;
|
||||
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape_map[kShape].size(), kEqual, kShapeSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("decay_shape size", decay_shape_map[kShape].size(), kEqual, kShapeSize,
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape_map[kShape].size(), kEqual, kShapeSize,
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape size", epsilon_shape_map[kShape].size(), kEqual, kShapeSize,
|
||||
op_name);
|
||||
auto shape_element = var_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
|
|
Loading…
Reference in New Issue