uadate L1Regularizer,upadate test_l1_regularizer_op.py

This commit is contained in:
“dangjiaqi1” 2020-12-25 17:04:13 +08:00
parent 537e940d6e
commit 86bbef5146
2 changed files with 34 additions and 5 deletions

View File

@ -47,7 +47,7 @@ class L1Regularizer(Cell):
scale(regularization factor) should be a number which greater than 0
Args:
scale (int, float) l1 regularization factor which greater than 0.
scale (int, float): l1 regularization factor which greater than 0.
Raises:
ValueError: If `scale(regularization factor)` is not greater than 0.
@ -57,7 +57,8 @@ class L1Regularizer(Cell):
- **weights** (Tensor) - The input tensor
Outputs:
Tensor, which dtype is Float and shape is ()
Tensor, which dtype is higher precision data type between mindspore.float32 and weights dtype,
and Tensor shape is ()
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -77,13 +78,13 @@ class L1Regularizer(Cell):
if scale <= 0:
raise ValueError("scale should be a number which greater than 0")
if math.isinf(scale) or math.isnan(scale):
raise ValueError("scale is INF or NAN")
raise ValueError("scale can not be INF or NAN")
self.abs = P.Abs()
self.reduce_sum = P.ReduceSum()
self.scale = Tensor(scale, dtype=mstype.float32)
def construct(self, weights):
const_utils.check_valid_type(weights.dtype, mstype.number_type, 'weights')
const_utils.check_valid_type(F.dtype(weights), mstype.number_type, 'weights')
l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
return l1_regularization
@ -278,7 +279,7 @@ class Dense(Cell):
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

View File

@ -58,3 +58,31 @@ def test_l1_regularizer08():
expect = 5.0
print("output : ", output.asnumpy())
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l1_regularizer_input_int():
scale = 0.5
net = nn.L1Regularizer(scale)
weights = 2
try:
output = net(weights)
print("output : ", output.asnumpy())
except TypeError:
assert True
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l1_regularizer_input_tuple():
scale = 0.5
net = nn.L1Regularizer(scale)
weights = (1, 2, 3, 4)
try:
output = net(weights)
print("output : ", output.asnumpy())
except TypeError:
assert True