diff --git a/example/Bert_NEZHA/config.py b/example/Bert_NEZHA/config.py
deleted file mode 100644
index 2f3b22fe500..00000000000
--- a/example/Bert_NEZHA/config.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-network config setting, will be used in main.py
-"""
-
-from easydict import EasyDict as edict
-import mindspore.common.dtype as mstype
-from mindspore.model_zoo.Bert_NEZHA import BertConfig
-bert_cfg = edict({
-    'epoch_size': 10,
-    'num_warmup_steps': 0,
-    'start_learning_rate': 1e-4,
-    'end_learning_rate': 1,
-    'decay_steps': 1000,
-    'power': 10.0,
-    'save_checkpoint_steps': 2000,
-    'keep_checkpoint_max': 10,
-    'checkpoint_prefix': "checkpoint_bert",
-    'DATA_DIR' = "/your/path/examples.tfrecord"
-    'SCHEMA_DIR' = "/your/path/datasetSchema.json"
-    'bert_config': BertConfig(
-                       batch_size=16,
-                       seq_length=128,
-                       vocab_size=21136,
-                       hidden_size=1024,
-                       num_hidden_layers=24,
-                       num_attention_heads=16,
-                       intermediate_size=4096,
-                       hidden_act="gelu",
-                       hidden_dropout_prob=0.0,
-                       attention_probs_dropout_prob=0.0,
-                       max_position_embeddings=512,
-                       type_vocab_size=2,
-                       initializer_range=0.02,
-                       use_relative_positions=True,
-                       input_mask_from_dataset=True,
-                       token_type_ids_from_dataset=True,
-                       dtype=mstype.float32,
-                       compute_type=mstype.float16, 
-                   )
-})
diff --git a/example/Bert_NEZHA_cnwiki/config.py b/example/Bert_NEZHA_cnwiki/config.py
new file mode 100644
index 00000000000..a704d9a2642
--- /dev/null
+++ b/example/Bert_NEZHA_cnwiki/config.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Network config settings, used in train.py.
+"""
+
+from easydict import EasyDict as edict
+import mindspore.common.dtype as mstype
+from mindspore.model_zoo.Bert_NEZHA import BertConfig
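+# training configuration (optimizer schedule, checkpointing, dataset paths) consumed by train.py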
+bert_train_cfg = edict({
+    'epoch_size': 10,
+    'num_warmup_steps': 0,
+    'start_learning_rate': 1e-4,
+    'end_learning_rate': 0.0,
+    'decay_steps': 1000,
+    'power': 10.0,
+    'save_checkpoint_steps': 2000,
+    'keep_checkpoint_max': 10,
+    'checkpoint_prefix': "checkpoint_bert",
+    # please add your own dataset path
+    'DATA_DIR': "/your/path/examples.tfrecord",
+    # please add your own dataset schema path
+    'SCHEMA_DIR': "/your/path/datasetSchema.json"
+})
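+# network (model architecture) configuration used to build the BERT NEZHA model in train.py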
+bert_net_cfg = BertConfig(
+    batch_size=16,
+    seq_length=128,
+    vocab_size=21136,
+    hidden_size=1024,
+    num_hidden_layers=24,
+    num_attention_heads=16,
+    intermediate_size=4096,
+    hidden_act="gelu",
+    hidden_dropout_prob=0.0,
+    attention_probs_dropout_prob=0.0,
+    max_position_embeddings=512,
+    type_vocab_size=2,
+    initializer_range=0.02,
+    use_relative_positions=True,
+    input_mask_from_dataset=True,
+    token_type_ids_from_dataset=True,
+    dtype=mstype.float32,
+    compute_type=mstype.float16,
+)
diff --git a/example/Bert_NEZHA/main.py b/example/Bert_NEZHA_cnwiki/train.py
similarity index 57%
rename from example/Bert_NEZHA/main.py
rename to example/Bert_NEZHA_cnwiki/train.py
index a5500f25a97..87f425e21c7 100644
--- a/example/Bert_NEZHA/main.py
+++ b/example/Bert_NEZHA_cnwiki/train.py
@@ -14,7 +14,8 @@
 # ============================================================================
 
 """
-NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
+NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is a pretrained Chinese language
+model developed by Huawei, currently based on BERT.
 1. Prepare data
 Following the data preparation as in BERT, run command as below to get dataset for training:
     python ./create_pretraining_data.py \
@@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f
       --random_seed=12345 \
       --dupe_factor=5
 2. Pretrain
-First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py.
+First, prepare the distributed training environment; then adjust the configurations in config.py and run train.py.
 """
 
 import os
-import pytest
 import numpy as np
-from numpy import allclose
-from config import bert_cfg as cfg
-import mindspore.common.dtype as mstype
+from config import bert_train_cfg, bert_net_cfg
 import mindspore.dataset.engine.datasets as de
 import mindspore._c_dataengine as deMap
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
-from mindspore.train.callback import Callback
-from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
+from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
 from mindspore.nn.optim import Lamb
-from mindspore import log as logger
 _current_dir = os.path.dirname(os.path.realpath(__file__))
-DATA_DIR = [cfg.DATA_DIR]
-SCHEMA_DIR = cfg.SCHEMA_DIR
 
-def me_de_train_dataset(batch_size):
-    """test me de train dataset"""
+def create_train_dataset(batch_size):
+    """create train dataset"""
     # apply repeat operations
-    repeat_count = cfg.epoch_size
-    ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
-                                                               "next_sentence_labels", "masked_lm_positions",
-                                                               "masked_lm_ids", "masked_lm_weights"])
+    repeat_count = bert_train_cfg.epoch_size
+    ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
+                           columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
+                                         "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
     type_cast_op = deMap.TypeCastOp("int32")
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -69,43 +64,35 @@ def me_de_train_dataset(batch_size):
     ds = ds.repeat(repeat_count)
     return ds
 
-
 def weight_variable(shape):
     """weight variable"""
     np.random.seed(1)
     ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
     return Tensor(ones)
 
-
-class ModelCallback(Callback):
-    def __init__(self):
-        super(ModelCallback, self).__init__()
-        self.loss_list = []
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        self.loss_list.append(cb_params.net_outputs.asnumpy()[0])
-        logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
-
-def test_bert_tdt():
-    """test bert tdt"""
+def train_bert():
+    """train bert"""
     context.set_context(mode=context.GRAPH_MODE)
     context.set_context(device_target="Ascend")
     context.set_context(enable_task_sink=True)
     context.set_context(enable_loop_sink=True)
     context.set_context(enable_mem_reuse=True)
-    parallel_callback = ModelCallback()
-    ds = me_de_train_dataset(cfg.bert_config.batch_size)
-    config = cfg.bert_config
-    netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate,
-                     end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False)
+    ds = create_train_dataset(bert_net_cfg.batch_size)
+    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
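+    # Lamb optimizer with warmup and polynomial learning rate decay taken from bert_train_cfg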
+    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
+                     start_learning_rate=bert_train_cfg.start_learning_rate,
+                     end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
+                     warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
     netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck)
-    model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False)
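+    # save checkpoints periodically and print the loss via LossMonitor during training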
+    config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
+    model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
 
 if __name__ == '__main__':
-    test_bert_tdt()
+    train_bert()