!28669 [assistant][ops]New operator implementation, include ApplyCenteredRMSProp

Merge pull request !28669 from ganqijun/ApplyCenteredRMSProp
This commit is contained in:
i-robot 2022-01-12 02:00:34 +00:00 committed by Gitee
commit 3a863c06fc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 0 additions and 12 deletions

View File

@ -38,10 +38,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
auto ms_shape_ptr = ms_shape->cast<abstract::ShapePtr>();
auto mom_shape_ptr = mom_shape->cast<abstract::ShapePtr>();
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
auto lr_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->GetShapeTrack());
auto decay_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->GetShapeTrack());
auto momentum_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[7]->GetShapeTrack());
auto epsilon_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[8]->GetShapeTrack());
// var and mg must have the same shape when is not dynamic
if (!var_shape_ptr->IsDynamic() && !mg_shape_ptr->IsDynamic()) {
if (*var_shape != *mg_shape) {
@ -70,14 +66,6 @@ abstract::ShapePtr ApplyCenteredRMSPropInferShape(const PrimitivePtr &primitive,
<< " are not consistent with var shape " << var_shape->ToString();
}
}
const int64_t kShapeSize = 0;
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape_map[kShape].size(), kEqual, kShapeSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("decay_shape size", decay_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
(void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
(void)CheckAndConvertUtils::CheckInteger("epsilon_shape size", epsilon_shape_map[kShape].size(), kEqual, kShapeSize,
op_name);
auto shape_element = var_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;