forked from mindspore-Ecosystem/mindspore
!3085 Adapt TBE ops ApplyCenteredRMSPropD.
Merge pull request !3085 from liuxiao93/adapt-ApplyCenteredRmsProp
This commit is contained in:
commit
dcacb6c918
|
@ -81,6 +81,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
|
||||
{"apply_add_sign", "apply_add_sign_d"},
|
||||
{"apply_power_sign", "apply_power_sign_d"},
|
||||
{"apply_centered_rms_prop", "apply_centered_rms_prop_d"},
|
||||
{"transpose", "transpose_d"},
|
||||
{"fill", "fill_d"},
|
||||
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
|
||||
|
|
|
@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
||||
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
||||
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
|
||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)},
|
||||
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
|
||||
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
|
||||
{string(kNameRange), ADPT_DESC(RangeD)},
|
||||
|
|
|
@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
|
|||
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// ApplyCenteredRMSProp
|
||||
INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
|
||||
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
|
||||
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
|
||||
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
|
||||
// ApplyCenteredRMSPropD
|
||||
INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
|
||||
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
|
||||
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
|
||||
ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyCenteredRMSPropD) = {
|
||||
{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
|
||||
|
||||
// L2Loss
|
||||
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
||||
|
|
|
@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2)
|
|||
DECLARE_OP_ADAPTER(ApplyRMSPropD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
||||
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
|
||||
DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
|
||||
DECLARE_OP_ADAPTER(L2Loss)
|
||||
DECLARE_OP_USE_OUTPUT(L2Loss)
|
||||
DECLARE_OP_ADAPTER(CTCLoss)
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ApplyCenteredRMSProp op"""
|
||||
"""ApplyCenteredRMSPropD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("apply_centered_rms_prop.so") \
|
||||
.binfile_name("apply_centered_rms_prop_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("apply_centered_rms_prop") \
|
||||
.kernel_name("apply_centered_rms_prop_d") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "mg", False, "required", "all") \
|
||||
|
@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
|||
.input(7, "epsilon", False, "required", "all") \
|
||||
.input(8, "grad", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "mg", False, "required", "all") \
|
||||
.output(2, "ms", False, "required", "all") \
|
||||
.output(3, "mom", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_5HD, DataType.F16_5HD) \
|
||||
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||
DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(apply_centered_rms_prop_op_info)
|
||||
def _apply_centered_rms_prop_tbe():
|
||||
"""ApplyCenteredRMSProp TBE register"""
|
||||
"""ApplyCenteredRMSPropD TBE register"""
|
||||
return
|
||||
|
|
|
@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.is_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
|
||||
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
|
||||
|
@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
|
||||
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
|
||||
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
||||
if self.is_ascend:
|
||||
return var_shape, mean_gradient_shape, mean_square_shape, moment_shape
|
||||
return var_shape
|
||||
|
||||
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
|
||||
|
@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
validator.check_type_same(args_rho, valid_types, self.name)
|
||||
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
|
||||
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
|
||||
if self.is_ascend:
|
||||
return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
|
||||
return var_dtype
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue