forked from mindspore-Ecosystem/mindspore
pylint clean
This commit is contained in:
parent
c8f69f5db2
commit
18f0af0529
|
@ -167,7 +167,7 @@ class BertAttentionMask(nn.Cell):
|
||||||
|
|
||||||
super(BertAttentionMask, self).__init__()
|
super(BertAttentionMask, self).__init__()
|
||||||
self.has_attention_mask = has_attention_mask
|
self.has_attention_mask = has_attention_mask
|
||||||
self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
|
self.multiply_data = Tensor([-1000.0,], dtype=dtype)
|
||||||
self.multiply = P.Mul()
|
self.multiply = P.Mul()
|
||||||
|
|
||||||
if self.has_attention_mask:
|
if self.has_attention_mask:
|
||||||
|
@ -198,7 +198,7 @@ class BertAttentionMaskBackward(nn.Cell):
|
||||||
dtype=mstype.float32):
|
dtype=mstype.float32):
|
||||||
super(BertAttentionMaskBackward, self).__init__()
|
super(BertAttentionMaskBackward, self).__init__()
|
||||||
self.has_attention_mask = has_attention_mask
|
self.has_attention_mask = has_attention_mask
|
||||||
self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
|
self.multiply_data = Tensor([-1000.0,], dtype=dtype)
|
||||||
self.multiply = P.Mul()
|
self.multiply = P.Mul()
|
||||||
self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
|
self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
|
||||||
if self.has_attention_mask:
|
if self.has_attention_mask:
|
||||||
|
|
|
@ -136,7 +136,7 @@ def test_LSTM():
|
||||||
train_network.set_train()
|
train_network.set_train()
|
||||||
|
|
||||||
train_features = Tensor(np.ones([64, max_len]).astype(np.int32))
|
train_features = Tensor(np.ones([64, max_len]).astype(np.int32))
|
||||||
train_labels = Tensor(np.ones([64, ]).astype(np.int32)[0:64])
|
train_labels = Tensor(np.ones([64,]).astype(np.int32)[0:64])
|
||||||
losses = []
|
losses = []
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
loss = train_network(train_features, train_labels)
|
loss = train_network(train_features, train_labels)
|
||||||
|
|
|
@ -34,7 +34,7 @@ ndarr = np.ones((2, 3))
|
||||||
|
|
||||||
def test_tensor_flatten():
|
def test_tensor_flatten():
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
lst = [1, 2, 3, 4, ]
|
lst = [1, 2, 3, 4,]
|
||||||
tensor_list = ms.Tensor(lst, ms.float32)
|
tensor_list = ms.Tensor(lst, ms.float32)
|
||||||
tensor_list = tensor_list.Flatten()
|
tensor_list = tensor_list.Flatten()
|
||||||
print(tensor_list)
|
print(tensor_list)
|
||||||
|
|
Loading…
Reference in New Issue