forked from mindspore-Ecosystem/mindspore
vm for LRN and LRNGrad
This commit is contained in:
parent
5b14292f69
commit
a1e148cb4d
|
@ -107,6 +107,8 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"r_oi_align_grad", "roi_align_grad"},
|
||||
{"i_ou", "iou"},
|
||||
{"s_gd", "sgd"},
|
||||
{"l_rn", "lrn"},
|
||||
{"l_rn_grad", "lrn_grad"},
|
||||
{"l_ars_update", "lars_v2_update"},
|
||||
{"n_ms_with_mask", "nms_with_mask"},
|
||||
{"square_sum_all", "square_sum_all"},
|
||||
|
|
|
@ -721,3 +721,15 @@ def get_bprop_basic_lstm_cell(self):
|
|||
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
||||
return dxt, dht, dct_1, dw, db
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.LRN)
|
||||
def get_bprop_lrn(self):
|
||||
"""Grad definition for `LRN` operation."""
|
||||
grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = grad(dout, x, out)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -267,3 +267,5 @@ from .lin_space import _lin_space_tbe
|
|||
from .matrix_diag import _matrix_diag_tbe
|
||||
from .matrix_diag_part import _matrix_diag_part_tbe
|
||||
from .matrix_set_diag import _matrix_set_diag_tbe
|
||||
from .lrn import _lrn_tbe
|
||||
from .lrn_grad import _lrn_grad_tbe
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LRN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lrn_op_info = TBERegOp("LRN") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lrn.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lrn") \
|
||||
.partial_flag(True) \
|
||||
.attr("depth_radius", "optional", "int", "all", "5") \
|
||||
.attr("bias", "optional", "float", "all", "1.0") \
|
||||
.attr("alpha", "optional", "float", "all", "1.0") \
|
||||
.attr("beta", "optional", "float", "all", "0.5") \
|
||||
.attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lrn_op_info)
|
||||
def _lrn_tbe():
|
||||
"""LRN TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LRNGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lrn_grad_op_info = TBERegOp("LRNGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lrn_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lrn_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("depth_radius", "optional", "int", "all") \
|
||||
.attr("bias", "optional", "float", "all") \
|
||||
.attr("alpha", "optional", "float", "all") \
|
||||
.attr("beta", "optional", "float", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lrn_grad_op_info)
|
||||
def _lrn_grad_tbe():
|
||||
"""LRNGrad TBE register"""
|
||||
return
|
|
@ -68,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
|||
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
||||
ResizeBilinear, Sigmoid,
|
||||
SigmoidCrossEntropyWithLogits,
|
||||
SmoothL1Loss, Softmax, Softplus,
|
||||
SmoothL1Loss, Softmax, Softplus, LRN,
|
||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||
|
@ -316,7 +316,8 @@ __all__ = [
|
|||
"DataFormatDimMap",
|
||||
"ApproximateEqual",
|
||||
"InplaceUpdate",
|
||||
"InTopK"
|
||||
"InTopK",
|
||||
"LRN"
|
||||
]
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -1364,3 +1364,22 @@ class InvGrad(PrimitiveWithInfer):
|
|||
validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
|
||||
validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
|
||||
return x
|
||||
|
||||
|
||||
class LRNGrad(PrimitiveWithInfer):
|
||||
"""Computes gradients for LRN operation."""
|
||||
@prim_attr_register
|
||||
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
|
||||
self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
|
||||
validator.check_value_type("depth_radius", depth_radius, [int], self.name)
|
||||
validator.check_value_type("bias", bias, [float], self.name)
|
||||
validator.check_value_type("alpha", alpha, [float], self.name)
|
||||
validator.check_value_type("beta", beta, [float], self.name)
|
||||
|
||||
def infer_dtype(self, grads, x, y):
|
||||
args = {"grads": grads, "x": x, "y": y}
|
||||
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
|
||||
return x
|
||||
|
||||
def infer_shape(self, grads, x, y):
|
||||
return x
|
||||
|
|
|
@ -4252,3 +4252,44 @@ class InTopK(PrimitiveWithInfer):
|
|||
validator.check("x2", len(x2_shape), "", 1, Rel.EQ, self.name)
|
||||
validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name)
|
||||
return x2_shape
|
||||
|
||||
|
||||
class LRN(PrimitiveWithInfer):
|
||||
r"""
|
||||
Local Response Normalization
|
||||
|
||||
Args:
|
||||
depth_radius (int): Half-width of the 1-D normalization window. Shape of 0-D.
|
||||
bias (float): An offset (usually positive to avoid dividing by 0).
|
||||
alpha (float): A scale factor, usually positive.
|
||||
beta (float): An exponent.
|
||||
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
|
||||
Default: "ACROSS_CHANNELS".
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, With shape and data type same as the input tensor.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.rand(1, 10, 4, 4)), mindspore.float32)
|
||||
>>> lrn = P.LRN()
|
||||
>>> lrn(x)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):
|
||||
"""Init LRN"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
validator.check_value_type("depth_radius", depth_radius, [int], self.name)
|
||||
validator.check_value_type("bias", bias, [float], self.name)
|
||||
validator.check_value_type("alpha", alpha, [float], self.name)
|
||||
validator.check_value_type("beta", beta, [float], self.name)
|
||||
validator.check_value_type("norm_region", norm_region, [str], self.name)
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)
|
||||
return x_dtype
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
|
|
@ -482,6 +482,29 @@ class PReLUGradNet(nn.Cell):
|
|||
def construct(self, dout, x, w):
|
||||
return self.prelu_grad(dout, x, w)
|
||||
|
||||
|
||||
class LRNNet(nn.Cell):
|
||||
""" LRNNet definition """
|
||||
|
||||
def __init__(self):
|
||||
super(LRNNet, self).__init__()
|
||||
self.lrn = P.LRN()
|
||||
|
||||
def construct(self, x):
|
||||
return self.lrn(x)
|
||||
|
||||
|
||||
class LRNGradNet(nn.Cell):
|
||||
""" LRNGradNet definition """
|
||||
|
||||
def __init__(self):
|
||||
super(LRNGradNet, self).__init__()
|
||||
self.lrn_grad = G.LRNGrad()
|
||||
|
||||
def construct(self, dout, x, out):
|
||||
return self.lrn_grad(dout, x, out)
|
||||
|
||||
|
||||
test_cases = [
|
||||
('SoftMaxGrad', {
|
||||
'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
|
||||
|
@ -593,6 +616,16 @@ test_cases = [
|
|||
Tensor(np.array([1, 2]).astype(np.float32))],
|
||||
'skip': ['backward']
|
||||
}),
|
||||
('LRNNet', {
|
||||
'block': LRNNet(),
|
||||
'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32))],
|
||||
}),
|
||||
('LRNGradNet', {
|
||||
'block': LRNGradNet(),
|
||||
'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32)),
|
||||
Tensor(np.ones([1, 5, 4, 4], np.float32)),
|
||||
Tensor(np.ones([1, 5, 4, 4], np.float32))],
|
||||
}),
|
||||
]
|
||||
|
||||
test_cases_for_verify_exception = [
|
||||
|
|
Loading…
Reference in New Issue