!3085 Adapt TBE ops ApplyCenteredRMSPropD.
Merge pull request !3085 from liuxiao93/adapt-ApplyCenteredRmsProp
This commit is contained in:
commit
dcacb6c918
|
@ -81,6 +81,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
|
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
|
||||||
{"apply_add_sign", "apply_add_sign_d"},
|
{"apply_add_sign", "apply_add_sign_d"},
|
||||||
{"apply_power_sign", "apply_power_sign_d"},
|
{"apply_power_sign", "apply_power_sign_d"},
|
||||||
|
{"apply_centered_rms_prop", "apply_centered_rms_prop_d"},
|
||||||
{"transpose", "transpose_d"},
|
{"transpose", "transpose_d"},
|
||||||
{"fill", "fill_d"},
|
{"fill", "fill_d"},
|
||||||
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
|
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
|
||||||
|
|
|
@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
||||||
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
||||||
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
||||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
|
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)},
|
||||||
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
|
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
|
||||||
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
|
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
|
||||||
{string(kNameRange), ADPT_DESC(RangeD)},
|
{string(kNameRange), ADPT_DESC(RangeD)},
|
||||||
|
|
|
@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
|
||||||
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
// ApplyCenteredRMSProp
|
// ApplyCenteredRMSPropD
|
||||||
INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
|
INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
|
||||||
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
|
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
|
||||||
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
|
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
|
||||||
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(ApplyCenteredRMSPropD) = {
|
||||||
|
{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
|
||||||
|
|
||||||
// L2Loss
|
// L2Loss
|
||||||
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
||||||
|
|
|
@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2)
|
||||||
DECLARE_OP_ADAPTER(ApplyRMSPropD)
|
DECLARE_OP_ADAPTER(ApplyRMSPropD)
|
||||||
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
||||||
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
|
DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
|
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
|
||||||
DECLARE_OP_ADAPTER(L2Loss)
|
DECLARE_OP_ADAPTER(L2Loss)
|
||||||
DECLARE_OP_USE_OUTPUT(L2Loss)
|
DECLARE_OP_USE_OUTPUT(L2Loss)
|
||||||
DECLARE_OP_ADAPTER(CTCLoss)
|
DECLARE_OP_ADAPTER(CTCLoss)
|
||||||
|
|
|
@ -13,15 +13,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""ApplyCenteredRMSProp op"""
|
"""ApplyCenteredRMSPropD op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
||||||
.fusion_type("OPAQUE") \
|
.fusion_type("OPAQUE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("apply_centered_rms_prop.so") \
|
.binfile_name("apply_centered_rms_prop_d.so") \
|
||||||
.compute_cost(10) \
|
.compute_cost(10) \
|
||||||
.kernel_name("apply_centered_rms_prop") \
|
.kernel_name("apply_centered_rms_prop_d") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.input(0, "var", False, "required", "all") \
|
.input(0, "var", False, "required", "all") \
|
||||||
.input(1, "mg", False, "required", "all") \
|
.input(1, "mg", False, "required", "all") \
|
||||||
|
@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
||||||
.input(7, "epsilon", False, "required", "all") \
|
.input(7, "epsilon", False, "required", "all") \
|
||||||
.input(8, "grad", False, "required", "all") \
|
.input(8, "grad", False, "required", "all") \
|
||||||
.output(0, "var", False, "required", "all") \
|
.output(0, "var", False, "required", "all") \
|
||||||
|
.output(1, "mg", False, "required", "all") \
|
||||||
|
.output(2, "ms", False, "required", "all") \
|
||||||
|
.output(3, "mom", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_5HD, DataType.F16_5HD) \
|
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_FracZ, DataType.F16_FracZ) \
|
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||||
|
DataType.F16_FracZ) \
|
||||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||||
|
DataType.F16_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_5HD, DataType.F32_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_FracZ, DataType.F32_FracZ) \
|
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||||
|
DataType.F32_FracZ) \
|
||||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||||
|
DataType.F32_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(apply_centered_rms_prop_op_info)
|
@op_info_register(apply_centered_rms_prop_op_info)
|
||||||
def _apply_centered_rms_prop_tbe():
|
def _apply_centered_rms_prop_tbe():
|
||||||
"""ApplyCenteredRMSProp TBE register"""
|
"""ApplyCenteredRMSPropD TBE register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, use_locking=False):
|
def __init__(self, use_locking=False):
|
||||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||||
|
self.is_ascend = context.get_context("device_target") == "Ascend"
|
||||||
|
|
||||||
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
|
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
|
||||||
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
|
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
|
||||||
|
@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||||
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
|
||||||
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
|
||||||
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
||||||
|
if self.is_ascend:
|
||||||
|
return var_shape, mean_gradient_shape, mean_square_shape, moment_shape
|
||||||
return var_shape
|
return var_shape
|
||||||
|
|
||||||
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
|
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
|
||||||
|
@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||||
validator.check_type_same(args_rho, valid_types, self.name)
|
validator.check_type_same(args_rho, valid_types, self.name)
|
||||||
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
|
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
|
||||||
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
|
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
|
||||||
|
if self.is_ascend:
|
||||||
|
return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
|
||||||
return var_dtype
|
return var_dtype
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue