Fix batch size check
This commit is contained in:
parent
11a3353239
commit
fe89ad2c49
|
@ -121,12 +121,12 @@ def train_and_eval(config):
|
|||
model = Model(train_net, eval_network=eval_net,
|
||||
metrics={"auc": auc_metric})
|
||||
|
||||
eval_callback = EvalCallBack(
|
||||
model, ds_eval, auc_metric, config)
|
||||
|
||||
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
|
||||
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
|
||||
|
||||
eval_callback = EvalCallBack(
|
||||
model, ds_eval, auc_metric, config)
|
||||
|
||||
callback = LossCallBack(config=config, per_print_times=20)
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
||||
keep_checkpoint_max=5, integrated_save=False)
|
||||
|
@ -146,10 +146,11 @@ if __name__ == "__main__":
|
|||
wide_deep_config = WideDeepConfig()
|
||||
wide_deep_config.argparse_init()
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=wide_deep_config.device_target, save_graphs=True)
|
||||
device_target=wide_deep_config.device_target)
|
||||
context.set_context(variable_memory_max_size="24GB")
|
||||
context.set_context(enable_sparse=True)
|
||||
init()
|
||||
context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True)
|
||||
if wide_deep_config.sparse:
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
|
||||
|
|
|
@ -209,14 +209,20 @@ def _get_tf_dataset(data_dir,
|
|||
shuffle=shuffle,
|
||||
schema=schema,
|
||||
num_parallel_workers=8)
|
||||
if batch_size <= 0:
|
||||
raise ValueError("Batch size should be a positive int value, but found {}".format(str(batch_size)))
|
||||
if batch_size % line_per_sample != 0:
|
||||
raise ValueError(
|
||||
"Batch size should be a multiple of {}, but found {}".format(str(line_per_sample), str(batch_size)))
|
||||
|
||||
data_set = data_set.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
|
||||
operations_list = []
|
||||
for key in columns_list:
|
||||
operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
|
||||
print("ssssssssssssssssssssss---------------------" * 10)
|
||||
print("input_shape_dict start logging")
|
||||
print(input_shape_dict)
|
||||
print("---------------------" * 10)
|
||||
print("input_shape_dict end logging")
|
||||
print(schema_dict)
|
||||
|
||||
def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
|
||||
|
|
Loading…
Reference in New Issue