!1320 add applyrmsprop cumprod and reduceprod for vm
Merge pull request !1320 from JichenZhao/applyrms_squaresumall_cumprod_reduceprod
commit 368007240c
@@ -92,7 +92,10 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"l_ars_update", "lars_v2_update"},
   {"n_ms_with_mask", "nms_with_mask"},
   {"square_sum_all", "square_sum_all"},
-  {"cum_sum", "cumsum_d"}};
+  {"cum_sum", "cumsum_d"},
+  {"apply_rms_prop", "apply_rms_prop_d"},
+  {"cum_prod", "cumprod_d"},
+  {"reduce_prod", "reduce_prod_d"}};

 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {
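These adapter entries feed the kernel-name lookup that runs after NormalizeFuncName snake-cases the MindSpore op name. A minimal Python sketch of that lookup, using only the entries shown above (the fallback behavior is an assumption):

# Sketch of the adapter lookup: snake_cased op names map to TBE kernel names;
# names without an entry are assumed to pass through unchanged.
tbe_func_adapter_map = {
    "cum_sum": "cumsum_d",
    "apply_rms_prop": "apply_rms_prop_d",
    "cum_prod": "cumprod_d",
    "reduce_prod": "reduce_prod_d",
}

def adapt_kernel_name(normalized_name):
    return tbe_func_adapter_map.get(normalized_name, normalized_name)

assert adapt_kernel_name("cum_prod") == "cumprod_d"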
@@ -167,6 +167,7 @@ const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
 const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
 const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
 const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
+const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");

 // NN
 const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -173,6 +173,7 @@ extern const PrimitivePtr kPrimEqual;
 extern const PrimitivePtr kPrimLess;
 extern const PrimitivePtr kPrimLessEqual;
 extern const PrimitivePtr kPrimCumSum;
+extern const PrimitivePtr kPrimCumProd;

 // NN
 extern const PrimitivePtr kPrimFlatten;
@@ -41,6 +41,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimOneHot->name(), {1});
   Register(prim::kPrimConcat->name(), {0});
   Register(prim::kPrimCumSum->name(), {1});
+  Register(prim::kPrimCumProd->name(), {1});
   Register(kUnsortedSegmentProdOpName, {2});
   Register(kUnsortedSegmentMinOpName, {2});
   Register(kSimpleMeanGradOpName, {1});
@@ -60,7 +61,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kResizeNearestNeighborGradOpName, {1});
   Register(kResizeNearestNeighborV2OpName, {1});
   Register(kResizeNearestNeighborV2GradOpName, {1});
-  Register(kApplyRMSPropOpname, {4, 5, 6});
+  Register(kApplyRMSPropOpname, {5, 6, 7});
   Register(kResizeBilinearV2OpName, {1});
   Register(kReduceProdOpName, {1});
   Register(kCumprodOpName, {1});
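The index change tracks the reordered ApplyRMSProp inputs: with learning_rate moved to input 3 and grad to input 4, the scalar constants rho, momentum and epsilon now sit at positions 5-7 and get folded into kernel attributes, while learning_rate stays a runtime tensor input. A sketch of that index arithmetic, using only the input order shown elsewhere in this commit:

# Which ApplyRMSProp inputs the registry converts to attributes after the reorder.
inputs = ['var', 'mean_square', 'moment', 'learning_rate', 'grad',
          'rho', 'momentum', 'epsilon']
as_attrs = [inputs[i] for i in (5, 6, 7)]
assert as_attrs == ['rho', 'momentum', 'epsilon']  # matches the three TBE attrs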
@@ -26,7 +26,7 @@ centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
     """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
-    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
+    success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
     return success
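For reference, the update that `opt` (an ApplyRMSProp primitive) applies, written as a hedged numpy sketch from the documented RMSProp rule; `decay` here is the same quantity the TBE kernel calls `rho`:

import numpy as np

def rmsprop_reference(var, grad, mean_square, moment, lr, decay, momentum, epsilon):
    # ms <- decay * ms + (1 - decay) * grad^2
    mean_square = decay * mean_square + (1.0 - decay) * grad * grad
    # mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
    moment = momentum * moment + lr * grad / np.sqrt(mean_square + epsilon)
    # var <- var - mom
    return var - moment, mean_square, moment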
@@ -180,7 +180,6 @@ from .check_valid import _check_valid_tbe
 from .iou import _iou_tbe
 from .arg_max import _arg_max_tbe
 from .nms_with_mask import _nms_with_mask_tbe
-from .random_choice_with_mask import _random_choice_with_mask_tbe
 from .sgd import _sgd_tbe
 from .lars_update import _lars_update_tbe
 from .bn_training_update_v2 import _bn_training_update_v2_tbe

@@ -195,3 +194,6 @@ from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
 from .sin import _sin_tbe
 from .cos import _cos_tbe
 from .cum_sum import _cum_sum_tbe
+from .apply_rms_prop import _apply_rms_prop_tbe
+from .cumprod import _cumprop_tbe
+from .reduce_prod import _reduce_prod_tbe
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ApplyRMSProp op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+apply_rms_prop_op_info = TBERegOp("ApplyRMSProp") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("apply_rms_prop.so") \
+    .compute_cost(10) \
+    .kernel_name("apply_rms_prop_d") \
+    .partial_flag(True) \
+    .attr("rho", "required", "float", "all") \
+    .attr("momentum", "required", "float", "all") \
+    .attr("epsilon", "required", "float", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "ms", False, "required", "all") \
+    .input(2, "mom", False, "required", "all") \
+    .input(3, "lr", False, "required", "all") \
+    .input(4, "grad", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "ms", False, "required", "all") \
+    .output(2, "mom", False, "required", "all") \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
+                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(apply_rms_prop_op_info)
+def _apply_rms_prop_tbe():
+    """ApplyRMSProp TBE register"""
+    return
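Each dtype_format row above lists one dtype/format pair per declared input and output, in declaration order: five inputs then three outputs, with lr (input 3) pinned to F32_Default in every row. A minimal sketch of the same builder pattern, reusing only TBERegOp methods that appear in this commit; "DemoOp" and its fields are hypothetical:

from mindspore.ops.op_info_register import TBERegOp, DataType

# Hypothetical op, for illustrating the builder chain only.
demo_op_info = TBERegOp("DemoOp") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("demo_op.so") \
    .compute_cost(10) \
    .kernel_name("demo_op") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()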
@@ -13,29 +13,30 @@
 # limitations under the License.
 # ============================================================================

-"""RandomChoiceWithMask op"""
+"""CumProd op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \
+cumprop_op_info = TBERegOp("CumProd") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("random_choice_with_mask.so") \
+    .binfile_name("cumprod_d.so") \
     .compute_cost(10) \
-    .kernel_name("random_choice_with_mask") \
+    .kernel_name("cumprod_d") \
     .partial_flag(True) \
-    .attr("max_shape", "optional", "listInt", "all") \
-    .attr("means", "optional", "listFloat", "all") \
-    .attr("stds", "optional", "listFloat", "all") \
-    .attr("wh_ratio_clip", "optional", "float", "all") \
-    .input(0, "rois", False, "required", "all") \
-    .input(1, "deltas", False, "required", "all") \
-    .output(0, "bboxes", False, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .attr("axis", "optional", "int", "all") \
+    .attr("exclusive", "optional", "bool", "all") \
+    .attr("reverse", "optional", "bool", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
     .get_op_info()


-@op_info_register(random_choice_with_mask_op_info)
-def _random_choice_with_mask_tbe():
-    """RandomChoiceWithMask TBE register"""
+@op_info_register(cumprop_op_info)
+def _cumprop_tbe():
+    """CumProd TBE register"""
     return
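A hedged usage sketch of the CumProd primitive this kernel backs; the axis is passed as the second call argument, consistent with the init_prim_io_names change in a later hunk:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
cumprod = P.CumProd()
y = cumprod(x, 0)  # expected [1., 2., 6.]: running product along axis 0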
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ReduceProd op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+reduce_prod_op_info = TBERegOp("ReduceProd") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("reduce_prod_d.so") \
+    .compute_cost(10) \
+    .kernel_name("reduce_prod_d") \
+    .partial_flag(True) \
+    .attr("axis", "required", "listInt", "all") \
+    .attr("keep_dims", "optional", "bool", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
+    .get_op_info()
+
+
+@op_info_register(reduce_prod_op_info)
+def _reduce_prod_tbe():
+    """ReduceProd TBE register"""
+    return
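A hedged usage sketch for ReduceProd, assuming the standard reduce-op call shape (tensor plus axis):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], np.float32))
reduce_prod = P.ReduceProd(keep_dims=False)
y = reduce_prod(x, 1)  # expected [2., 12.]: product over each row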
@@ -475,6 +475,7 @@ class CumProd(PrimitiveWithInfer):
         cls_name = self.name
         self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
         self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
+        self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])

     def infer_shape(self, x_shape, axis_shape):
         return x_shape
@@ -2022,8 +2023,10 @@ class NMSWithMask(PrimitiveWithInfer):
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
-        validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
-        num = bboxes_shape[0]
-        return (bboxes_shape, (num,), (num,))
+        if not self.is_ge:
+            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
+        else:
+            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
+        num = bboxes_shape[0]
+        return ((num, 5), (num,), (num,))
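A hedged restatement of the new shape rule: on the Ascend VM path (is_ge false) the TBE kernel consumes (N, 8) box rows yet still emits (N, 5) selected boxes, so both branches infer the same output shapes:

def nms_with_mask_output_shapes(bboxes_shape, is_ge):
    # Width check mirrors the branch above; outputs are (N, 5), (N,), (N,).
    expected_width = 5 if is_ge else 8
    assert bboxes_shape[1] == expected_width
    num = bboxes_shape[0]
    return (num, 5), (num,), (num,)

assert nms_with_mask_output_shapes((16, 8), is_ge=False) == ((16, 5), (16,), (16,))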
@@ -2171,8 +2174,8 @@ class SquareSumAll(PrimitiveWithInfer):
         - **output_y2** (Tensor) - The same type as the `input_x1`.

     Examples:
-        >>> input_x1 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
-        >>> input_x2 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
+        >>> input_x1 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
+        >>> input_x2 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
         >>> square_sum_all = P.SquareSumAll()
         >>> square_sum_all(input_x1, input_x2)
     """
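Note the docstring example still misuses np.random.randint: passing the list [3, 2, 5, 7] yields four scalar draws, not a tensor of shape (3, 2, 5, 7). A hedged sketch of what was likely intended:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_x1 = Tensor(np.random.rand(3, 2, 5, 7), mindspore.float32)
input_x2 = Tensor(np.random.rand(3, 2, 5, 7), mindspore.float32)
square_sum_all = P.SquareSumAll()
output_y1, output_y2 = square_sum_all(input_x1, input_x2)  # two scalar sums of squares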
@@ -1721,15 +1721,21 @@ class ApplyRMSProp(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, use_locking=False):
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
+                                        'rho', 'momentum', 'epsilon'], outputs=['output'])
+        self.is_ge = context.get_context("enable_ge")
+        self.is_d = context.get_context("device_target") == "Ascend"

-    def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
+    def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
                     momentum_shape, epsilon_shape):
         validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
+        if not self.is_ge and self.is_d:
+            return var_shape, var_shape, var_shape
         return var_shape

-    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
+    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
                     momentum_dtype, epsilon_dtype):
         args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)

@@ -1739,6 +1745,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
         validator.check_type_same(args_decay, valid_types, self.name)
         args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
         validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
+        if not self.is_ge and self.is_d:
+            return var_dtype, var_dtype, var_dtype
         return var_dtype
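A hedged sketch of the branchy inference added above: on Ascend VM (enable_ge false, device_target "Ascend") the primitive reports three outputs to match apply_rms_prop_d's (var, ms, mom); everywhere else it stays single-output:

def apply_rms_prop_infer_shape(var_shape, is_ge, is_d):
    if not is_ge and is_d:
        return var_shape, var_shape, var_shape  # Ascend VM: var, ms, mom
    return var_shape                            # GE or non-Ascend: single output

assert apply_rms_prop_infer_shape((3, 3), False, True) == ((3, 3), (3, 3), (3, 3))
assert apply_rms_prop_infer_shape((3, 3), True, True) == (3, 3)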
@@ -37,7 +37,7 @@ class NetRMSProp(nn.Cell):
         if self.use_centered:
             return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
         else:
-            return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon)
+            return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)


 def rmsprop_numpy(variable, gradients, mean_square, moment,
@@ -202,6 +202,21 @@ class ApplyFtrlNet(nn.Cell):
         out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
         return out

+class ApplyRMSNet(nn.Cell):
+    def __init__(self):
+        super(ApplyRMSNet, self).__init__()
+        self.apply_rms = P.ApplyRMSProp()
+        self.lr = 0.001
+        self.rho = 0.0
+        self.momentum = 0.0
+        self.epsilon = 1e-10
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.ms = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="ms")
+        self.moment = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="moment")
+
+    def construct(self, grad):
+        out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
+        return out
+
 test_case_math_ops = [
     ('Neg', {
@@ -914,9 +929,8 @@ test_case_nn_ops = [
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyRMSProp', {
-        'block': P.ApplyRMSProp(),
-        'desc_const': [0.9, 0.0, 1e-10, 0.001],
-        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
+        'block': ApplyRMSNet(),
+        'desc_inputs': [[3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyCenteredRMSProp', {
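Since the scalar hyper-parameters are now attributes (via the const-input-to-attr indices registered earlier in this commit), the test wraps them inside a Cell instead of feeding them through desc_const. A hedged usage sketch of the new test net:

import numpy as np
from mindspore import Tensor

net = ApplyRMSNet()  # the test net added in the hunk above
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
out = net(grad)      # only the gradient is a runtime input now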