forked from mindspore-Ecosystem/mindspore
!2199 Fix the condition when activation name is 0
Merge pull request !2199 from Simson/push-to-opensource
This commit is contained in:
commit
4e3b3711be
|
@ -549,9 +549,9 @@ def get_activation(name):
|
||||||
Examples:
|
Examples:
|
||||||
>>> sigmoid = nn.get_activation('sigmoid')
|
>>> sigmoid = nn.get_activation('sigmoid')
|
||||||
"""
|
"""
|
||||||
if not name:
|
if name is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if name not in _activation:
|
if name not in _activation:
|
||||||
raise KeyError("Unknown activation type")
|
raise KeyError(f"Unknown activation type '{name}'")
|
||||||
return _activation[name]()
|
return _activation[name]()
|
||||||
|
|
|
@ -76,7 +76,7 @@ class Net(nn.Cell):
|
||||||
weight='normal',
|
weight='normal',
|
||||||
bias='zeros',
|
bias='zeros',
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation=''):
|
activation=None):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.dense = nn.Dense(input_channels,
|
self.dense = nn.Dense(input_channels,
|
||||||
output_channels,
|
output_channels,
|
||||||
|
|
|
@ -46,10 +46,6 @@ def test_activation_param():
|
||||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||||
|
|
||||||
|
|
||||||
def test_activation_empty():
|
|
||||||
assert nn.get_activation('') is None
|
|
||||||
|
|
||||||
|
|
||||||
# test softmax
|
# test softmax
|
||||||
def test_softmax_axis():
|
def test_softmax_axis():
|
||||||
layer = nn.Softmax(1)
|
layer = nn.Softmax(1)
|
||||||
|
|
Loading…
Reference in New Issue