fix bug of bert_thor

This commit is contained in:
wangmin 2020-09-01 20:59:15 +08:00
parent 36977394cf
commit d24ce34c36
5 changed files with 16 additions and 7 deletions

View File

@ -201,7 +201,7 @@ step: 3000 Accuracy: [0.71377236]
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss |1.5654222 |
| Speed | 269ms/step8pcs |
| Speed | 275ms/step8pcs |
| Total time | 14 mins |
| Parameters (M) | 330 |
| Checkpoint for Fine tuning | 4.5G(.ckpt file) |

View File

@ -155,10 +155,11 @@ def MLM_eval():
res = net.eval(dataset, dataset_sink_mode=False)
print("==============================================================")
for _, v in res.items():
print("Accuracy is: ")
print(v)
print("Accuracy is: ", v)
print("==============================================================")
if __name__ == "__main__":
DEVICE_ID = 1
os.environ['DEVICE_ID'] = str(DEVICE_ID)
MLM_eval()

View File

@ -26,7 +26,6 @@ from src.config import cfg
from src.dataset import create_bert_dataset
from src.lr_generator import get_bert_lr, get_bert_damping
from src.model_thor import Model
from src.thor_for_bert_arg import THOR
from src.utils import LossCallBack, BertLearningRate
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
@ -66,10 +65,15 @@ def run_pretrain():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
if args_opt.distribute == "true":
from src.thor_for_bert_arg import THOR
else:
from src.thor_for_bert import THOR
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
device_id=args_opt.device_id, save_graphs=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")
context.set_context(max_call_depth=3000)
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':

View File

@ -231,16 +231,17 @@ class EmbeddingPostprocessor(nn.Cell):
frequency=frequency)
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
self.layernorm = nn.LayerNorm((embedding_size,))
self.add = P.TensorAdd()
def construct(self, token_type_ids, word_embeddings):
"""construct of EmbeddingPostprocessor"""
output = word_embeddings
if self.use_token_type:
token_type_embeddings, _ = self.token_type_embedding(token_type_ids)
output += token_type_embeddings
output = self.add(output, token_type_embeddings)
if not self.use_relative_positions:
position_embeddings, _ = self.full_position_embedding(self.position_ids)
output += position_embeddings
output = self.add(output, position_embeddings)
output = self.layernorm(output)
output = self.dropout(output)
return output

View File

@ -101,6 +101,8 @@ class FusedLayerNorm(Cell):
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
self.mul = P.Mul()
self.add = P.TensorAdd()
def construct(self, input_x):
"""construct of FusedLayerNorm"""
@ -112,7 +114,8 @@ class FusedLayerNorm(Cell):
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
y = self.mul(output, self.gamma)
y = self.add(y, self.beta)
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y