!4933 Improve the Outputs Quality and Change Saving Frequency

Merge pull request !4933 from huangxinjing/wide-deep-checkpoint
This commit is contained in:
mindspore-ci-bot 2020-08-26 00:16:12 +08:00 committed by Gitee
commit 1afb8749b6
5 changed files with 12 additions and 9 deletions

View File

@ -49,7 +49,7 @@ class LossCallBack(Callback):
cb_params.net_outputs[1].asnumpy() cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True) print("Status:", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
if self._per_print_times != 0 and cur_num % self._per_print_times == 0: if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_file = open(self.config.loss_file_name, "a+") loss_file = open(self.config.loss_file_name, "a+")
loss_file.write( loss_file.write(
@ -91,6 +91,6 @@ class EvalCallBack(Callback):
end_time = time.time() end_time = time.time()
eval_time = int(end_time - start_time) eval_time = int(end_time - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time) out_str = "{}:EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time)
print(out_str) print(out_str)
add_write(self.eval_file_name, out_str) add_write(self.eval_file_name, out_str)

View File

@ -22,7 +22,7 @@ def argparse_init():
parser = argparse.ArgumentParser(description='WideDeep') parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data. parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data.
parser.add_argument("--epochs", type=int, default=200) # The number of epochs used to train. parser.add_argument("--epochs", type=int, default=8) # The number of epochs used to train.
parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation
parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation. parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation.
parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP

View File

@ -148,6 +148,5 @@ class AUCMetric(Metric):
auc = roc_auc_score(self.true_labels, self.pred_probs) auc = roc_auc_score(self.true_labels, self.pred_probs)
MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12) MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
print("=====" * 20 + " auc_metric end ") print("Eval result:" + " auc: {}, map: {}".format(auc, MAP))
print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP))
return auc return auc

View File

@ -89,13 +89,15 @@ def train_and_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config) callback = LossCallBack(config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), # Only save the last checkpoint at the last epoch. For saving epochs at each epoch, please
# set save_checkpoint_steps=ds_train.get_dataset_size()
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*config.epochs,
keep_checkpoint_max=10) keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig) directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
callback, ckpoint_cb]) callback, ckpoint_cb], sink_size=ds_train.get_dataset_size())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -94,14 +94,16 @@ def train_and_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config) callback = LossCallBack(config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), # Only save the last checkpoint at the last epoch. For saving epochs at each epoch, please
# set save_checkpoint_steps=ds_train.get_dataset_size()
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*config.epochs,
keep_checkpoint_max=10) keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig) directory=config.ckpt_path, config=ckptconfig)
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if int(get_rank()) == 0: if int(get_rank()) == 0:
callback_list.append(ckpoint_cb) callback_list.append(ckpoint_cb)
model.train(epochs, ds_train, callbacks=callback_list) model.train(epochs, ds_train, callbacks=callback_list, sink_size=ds_train.get_dataset_size())
if __name__ == "__main__": if __name__ == "__main__":