Fix batch size check

huangxinjing 2020-12-24 11:29:53 +08:00
parent 11a3353239
commit fe89ad2c49
2 changed files with 13 additions and 6 deletions

@@ -121,12 +121,12 @@ def train_and_eval(config):
     model = Model(train_net, eval_network=eval_net,
                   metrics={"auc": auc_metric})
-    eval_callback = EvalCallBack(
-        model, ds_eval, auc_metric, config)
     # Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
     config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
+    eval_callback = EvalCallBack(
+        model, ds_eval, auc_metric, config)
     callback = LossCallBack(config=config, per_print_times=20)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
                                   keep_checkpoint_max=5, integrated_save=False)
@@ -146,10 +146,11 @@ if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
     context.set_context(mode=context.GRAPH_MODE,
-                        device_target=wide_deep_config.device_target, save_graphs=True)
+                        device_target=wide_deep_config.device_target)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
     init()
+    context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True)
     if wide_deep_config.sparse:
         context.set_auto_parallel_context(
             parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
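Taken together, the two hunks above make every rank-dependent path valid at the point it is used: get_rank() is only meaningful after init(), and the per-rank strategy-checkpoint path must be set before the callbacks that consume it are constructed. A minimal standalone sketch of the per-rank layout, assuming MindSpore's communication API; the base directory name strategy_ckpt is a placeholder, not taken from the commit:

    import os
    from mindspore import context
    from mindspore.communication.management import init, get_rank

    init()  # must run before get_rank() returns a valid rank id
    rank_id = get_rank()
    # Each device writes to its own directory so parallel ranks do not
    # overwrite each other's strategy checkpoints or saved graphs.
    stra_ckpt = os.path.join("strategy_ckpt-{}".format(rank_id), "strategy.ckpt")
    context.set_context(save_graphs=True,
                        save_graphs_path='./graphs_of_device_id_' + str(rank_id))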

@@ -209,14 +209,20 @@ def _get_tf_dataset(data_dir,
                                        shuffle=shuffle,
                                        schema=schema,
                                        num_parallel_workers=8)
+    if batch_size <= 0:
+        raise ValueError("Batch size should be a positive int value, but found {}".format(str(batch_size)))
+    if batch_size % line_per_sample != 0:
+        raise ValueError(
+            "Batch size should be a multiple of {}, but found {}".format(str(line_per_sample), str(batch_size)))
     data_set = data_set.batch(int(batch_size / line_per_sample), drop_remainder=True)
     operations_list = []
     for key in columns_list:
         operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
-    print("ssssssssssssssssssssss---------------------" * 10)
+    print("input_shape_dict start logging")
     print(input_shape_dict)
-    print("---------------------" * 10)
+    print("input_shape_dict end logging")
     print(schema_dict)
     def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
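The new check rejects batch sizes that are non-positive or not a multiple of line_per_sample: each stored row of the TFRecord data packs line_per_sample samples, so batching happens in rows of int(batch_size / line_per_sample). A standalone sketch of the same validation; check_batch_size is a hypothetical helper and the default of 1000 is illustrative, not taken from the commit:

    def check_batch_size(batch_size, line_per_sample=1000):
        # Mirrors the validation added to _get_tf_dataset above.
        if batch_size <= 0:
            raise ValueError("Batch size should be a positive int value, but found {}".format(batch_size))
        if batch_size % line_per_sample != 0:
            raise ValueError("Batch size should be a multiple of {}, but found {}".format(line_per_sample, batch_size))
        # Each stored row holds line_per_sample samples, so batch in rows.
        return batch_size // line_per_sample

    check_batch_size(16000)  # -> 16 rows per batch
    check_batch_size(15500)  # raises ValueError: not a multiple of 1000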