!3085 Adapt TBE ops ApplyCenteredRMSPropD.

Merge pull request !3085 from liuxiao93/adapt-ApplyCenteredRmsProp
This commit is contained in:
mindspore-ci-bot 2020-07-16 14:33:56 +08:00 committed by Gitee
commit dcacb6c918
6 changed files with 39 additions and 21 deletions

View File

@ -81,6 +81,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"apply_add_sign", "apply_add_sign_d"},
{"apply_power_sign", "apply_power_sign_d"},
{"apply_centered_rms_prop", "apply_centered_rms_prop_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},

View File

@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
{string(kNameAtan2), ADPT_DESC(Atan2)},
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)},
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
{string(kNameRange), ADPT_DESC(RangeD)},

View File

@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
// ApplyCenteredRMSProp
INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
// ApplyCenteredRMSPropD
INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyCenteredRMSPropD) = {
{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
// L2Loss
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};

View File

@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2)
DECLARE_OP_ADAPTER(ApplyRMSPropD)
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
DECLARE_OP_ADAPTER(L2Loss)
DECLARE_OP_USE_OUTPUT(L2Loss)
DECLARE_OP_ADAPTER(CTCLoss)

View File

@ -13,15 +13,15 @@
# limitations under the License.
# ============================================================================
"""ApplyCenteredRMSProp op"""
"""ApplyCenteredRMSPropD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_centered_rms_prop.so") \
.binfile_name("apply_centered_rms_prop_d.so") \
.compute_cost(10) \
.kernel_name("apply_centered_rms_prop") \
.kernel_name("apply_centered_rms_prop_d") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "mg", False, "required", "all") \
@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
.input(7, "epsilon", False, "required", "all") \
.input(8, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "mg", False, "required", "all") \
.output(2, "ms", False, "required", "all") \
.output(3, "mom", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_5HD, DataType.F16_5HD) \
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_FracZ, DataType.F16_FracZ) \
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.get_op_info()
@op_info_register(apply_centered_rms_prop_op_info)
def _apply_centered_rms_prop_tbe():
"""ApplyCenteredRMSProp TBE register"""
"""ApplyCenteredRMSPropD TBE register"""
return

View File

@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.is_ascend = context.get_context("device_target") == "Ascend"
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
if self.is_ascend:
return var_shape, mean_gradient_shape, mean_square_shape, moment_shape
return var_shape
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
validator.check_type_same(args_rho, valid_types, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
if self.is_ascend:
return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
return var_dtype