forked from mindspore-Ecosystem/mindspore
uadate L1Regularizer,upadate test_l1_regularizer_op.py
This commit is contained in:
parent
537e940d6e
commit
86bbef5146
|
@ -47,7 +47,7 @@ class L1Regularizer(Cell):
|
|||
scale(regularization factor) should be a number which greater than 0
|
||||
|
||||
Args:
|
||||
scale (int, float): l1 regularization factor which greater than 0.
|
||||
scale (int, float): l1 regularization factor which greater than 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If `scale(regularization factor)` is not greater than 0.
|
||||
|
@ -57,7 +57,8 @@ class L1Regularizer(Cell):
|
|||
- **weights** (Tensor) - The input tensor
|
||||
|
||||
Outputs:
|
||||
Tensor, which dtype is Float and shape is ()
|
||||
Tensor, which dtype is higher precision data type between mindspore.float32 and weights dtype,
|
||||
and Tensor shape is ()
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -77,13 +78,13 @@ class L1Regularizer(Cell):
|
|||
if scale <= 0:
|
||||
raise ValueError("scale should be a number which greater than 0")
|
||||
if math.isinf(scale) or math.isnan(scale):
|
||||
raise ValueError("scale is INF or NAN")
|
||||
raise ValueError("scale can not be INF or NAN")
|
||||
self.abs = P.Abs()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.scale = Tensor(scale, dtype=mstype.float32)
|
||||
|
||||
def construct(self, weights):
|
||||
const_utils.check_valid_type(weights.dtype, mstype.number_type, 'weights')
|
||||
const_utils.check_valid_type(F.dtype(weights), mstype.number_type, 'weights')
|
||||
l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
|
||||
return l1_regularization
|
||||
|
||||
|
@ -278,7 +279,7 @@ class Dense(Cell):
|
|||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("Weight init shape error.")
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
|
|
|
@ -58,3 +58,31 @@ def test_l1_regularizer08():
|
|||
expect = 5.0
|
||||
print("output : ", output.asnumpy())
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_l1_regularizer_input_int():
|
||||
scale = 0.5
|
||||
net = nn.L1Regularizer(scale)
|
||||
weights = 2
|
||||
try:
|
||||
output = net(weights)
|
||||
print("output : ", output.asnumpy())
|
||||
except TypeError:
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_l1_regularizer_input_tuple():
|
||||
scale = 0.5
|
||||
net = nn.L1Regularizer(scale)
|
||||
weights = (1, 2, 3, 4)
|
||||
try:
|
||||
output = net(weights)
|
||||
print("output : ", output.asnumpy())
|
||||
except TypeError:
|
||||
assert True
|
||||
|
|
Loading…
Reference in New Issue