forked from mindspore-Ecosystem/mindspore
!4933 Improve the Outputs Quality and Change Saving Frequency
Merge pull request !4933 from huangxinjing/wide-deep-checkpoint
This commit is contained in:
commit
1afb8749b6
|
@ -49,7 +49,7 @@ class LossCallBack(Callback):
|
||||||
cb_params.net_outputs[1].asnumpy()
|
cb_params.net_outputs[1].asnumpy()
|
||||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||||
cur_num = cb_params.cur_step_num
|
cur_num = cb_params.cur_step_num
|
||||||
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
|
print("Status:", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
|
||||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||||
loss_file = open(self.config.loss_file_name, "a+")
|
loss_file = open(self.config.loss_file_name, "a+")
|
||||||
loss_file.write(
|
loss_file.write(
|
||||||
|
@ -91,6 +91,6 @@ class EvalCallBack(Callback):
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
eval_time = int(end_time - start_time)
|
eval_time = int(end_time - start_time)
|
||||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time)
|
out_str = "{}:EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time)
|
||||||
print(out_str)
|
print(out_str)
|
||||||
add_write(self.eval_file_name, out_str)
|
add_write(self.eval_file_name, out_str)
|
||||||
|
|
|
@ -22,7 +22,7 @@ def argparse_init():
|
||||||
parser = argparse.ArgumentParser(description='WideDeep')
|
parser = argparse.ArgumentParser(description='WideDeep')
|
||||||
|
|
||||||
parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data.
|
parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data.
|
||||||
parser.add_argument("--epochs", type=int, default=200) # The number of epochs used to train.
|
parser.add_argument("--epochs", type=int, default=8) # The number of epochs used to train.
|
||||||
parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation
|
parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation
|
||||||
parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation.
|
parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation.
|
||||||
parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP
|
parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP
|
||||||
|
|
|
@ -148,6 +148,5 @@ class AUCMetric(Metric):
|
||||||
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
||||||
|
|
||||||
MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
|
MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
|
||||||
print("=====" * 20 + " auc_metric end ")
|
print("Eval result:" + " auc: {}, map: {}".format(auc, MAP))
|
||||||
print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP))
|
|
||||||
return auc
|
return auc
|
||||||
|
|
|
@ -89,13 +89,15 @@ def train_and_eval(config):
|
||||||
|
|
||||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
||||||
callback = LossCallBack(config)
|
callback = LossCallBack(config)
|
||||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
|
# Only save the last checkpoint at the last epoch. To save a checkpoint at every epoch, please
|
||||||
|
# set save_checkpoint_steps=ds_train.get_dataset_size()
|
||||||
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*config.epochs,
|
||||||
keep_checkpoint_max=10)
|
keep_checkpoint_max=10)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||||
directory=config.ckpt_path, config=ckptconfig)
|
directory=config.ckpt_path, config=ckptconfig)
|
||||||
|
|
||||||
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
|
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
|
||||||
callback, ckpoint_cb])
|
callback, ckpoint_cb], sink_size=ds_train.get_dataset_size())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -94,14 +94,16 @@ def train_and_eval(config):
|
||||||
|
|
||||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
||||||
callback = LossCallBack(config)
|
callback = LossCallBack(config)
|
||||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
|
# Only save the last checkpoint at the last epoch. To save a checkpoint at every epoch, please
|
||||||
|
# set save_checkpoint_steps=ds_train.get_dataset_size()
|
||||||
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*config.epochs,
|
||||||
keep_checkpoint_max=10)
|
keep_checkpoint_max=10)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||||
directory=config.ckpt_path, config=ckptconfig)
|
directory=config.ckpt_path, config=ckptconfig)
|
||||||
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||||
if int(get_rank()) == 0:
|
if int(get_rank()) == 0:
|
||||||
callback_list.append(ckpoint_cb)
|
callback_list.append(ckpoint_cb)
|
||||||
model.train(epochs, ds_train, callbacks=callback_list)
|
model.train(epochs, ds_train, callbacks=callback_list, sink_size=ds_train.get_dataset_size())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue