Fix batch size check

parent 11a3353239
commit fe89ad2c49
@@ -121,12 +121,12 @@ def train_and_eval(config):
     model = Model(train_net, eval_network=eval_net,
                   metrics={"auc": auc_metric})
 
-    eval_callback = EvalCallBack(
-        model, ds_eval, auc_metric, config)
-
     # Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
     config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
 
+    eval_callback = EvalCallBack(
+        model, ds_eval, auc_metric, config)
+
     callback = LossCallBack(config=config, per_print_times=20)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
                                   keep_checkpoint_max=5, integrated_save=False)
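The reordering above is the point of the first hunk: config.stra_ckpt must receive its rank suffix before EvalCallBack is constructed, which suggests the callback captures config state at init time, as the in-code comment implies. A minimal sketch of that pitfall with a hypothetical callback (SnapshotCallback and Config are illustrative, not the project's classes):

class SnapshotCallback:
    def __init__(self, config):
        # The path is read once, at construction time.
        self.stra_ckpt = config.stra_ckpt

class Config:
    stra_ckpt = "./strategy"

cfg = Config()
cb = SnapshotCallback(cfg)                  # constructed before the path is finalized
cfg.stra_ckpt = cfg.stra_ckpt + "-0/strategy.ckpt"
print(cb.stra_ckpt)                         # "./strategy" -- the rank suffix never reaches the callback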
@ -146,10 +146,11 @@ if __name__ == "__main__":
|
||||||
wide_deep_config = WideDeepConfig()
|
wide_deep_config = WideDeepConfig()
|
||||||
wide_deep_config.argparse_init()
|
wide_deep_config.argparse_init()
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target=wide_deep_config.device_target, save_graphs=True)
|
device_target=wide_deep_config.device_target)
|
||||||
context.set_context(variable_memory_max_size="24GB")
|
context.set_context(variable_memory_max_size="24GB")
|
||||||
context.set_context(enable_sparse=True)
|
context.set_context(enable_sparse=True)
|
||||||
init()
|
init()
|
||||||
|
context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True)
|
||||||
if wide_deep_config.sparse:
|
if wide_deep_config.sparse:
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
|
||||||
|
|
|
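In the second hunk, save_graphs=True moves out of the global set_context call and into a dedicated one issued after init(), so the dump path can include the rank id and concurrent devices stop writing IR graphs over each other. A throwaway illustration of the resulting directory scheme (the eight-rank count is an assumption, not from the source):

for rank in range(8):
    print('./graphs_of_device_id_' + str(rank))
# ./graphs_of_device_id_0, ./graphs_of_device_id_1, ... one dump directory per device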
@@ -209,14 +209,20 @@ def _get_tf_dataset(data_dir,
                                shuffle=shuffle,
                                schema=schema,
                                num_parallel_workers=8)
+    if batch_size <= 0:
+        raise ValueError("Batch size should be a positive int value, but found {}".format(str(batch_size)))
+    if batch_size % line_per_sample != 0:
+        raise ValueError(
+            "Batch size should be a multiple of {}, but found {}".format(str(line_per_sample), str(batch_size)))
+
     data_set = data_set.batch(int(batch_size / line_per_sample), drop_remainder=True)
 
     operations_list = []
     for key in columns_list:
         operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
-    print("ssssssssssssssssssssss---------------------" * 10)
+    print("input_shape_dict start logging")
     print(input_shape_dict)
-    print("---------------------" * 10)
+    print("input_shape_dict end logging")
     print(schema_dict)
 
 def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
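The third hunk carries the fix named in the commit title. Each dataset row packs line_per_sample samples, so batch() consumes batch_size / line_per_sample rows per step; without the new guards, a batch_size that is not an exact multiple would be silently truncated by the integer division rather than rejected. A standalone sketch of the same validation (check_batch_size is a hypothetical helper name, not from the source):

def check_batch_size(batch_size, line_per_sample):
    # Mirrors the two guards added in this commit.
    if batch_size <= 0:
        raise ValueError("Batch size should be a positive int value, but found {}".format(batch_size))
    if batch_size % line_per_sample != 0:
        raise ValueError(
            "Batch size should be a multiple of {}, but found {}".format(line_per_sample, batch_size))
    # Rows that batch() actually consumes per step.
    return batch_size // line_per_sample

print(check_batch_size(16000, 1000))   # 16 rows per batch
# check_batch_size(1500, 1000) would raise: 1500 is not a multiple of 1000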