diff --git a/model_zoo/wide_and_deep/tools/config.py b/model_zoo/wide_and_deep/tools/config.py new file mode 100644 index 00000000000..8d87904be01 --- /dev/null +++ b/model_zoo/wide_and_deep/tools/config.py @@ -0,0 +1,91 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" config. """ +import argparse + + +def argparse_init(): + """ + argparse_init + """ + parser = argparse.ArgumentParser(description='WideDeep') + parser.add_argument("--data_path", type=str, default="./test_raw_data/") + parser.add_argument("--epochs", type=int, default=15) + parser.add_argument("--batch_size", type=int, default=10000) + parser.add_argument("--eval_batch_size", type=int, default=15) + parser.add_argument("--field_size", type=int, default=39) + parser.add_argument("--vocab_size", type=int, default=184965) + parser.add_argument("--emb_dim", type=int, default=80) + parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) + parser.add_argument("--deep_layer_act", type=str, default='relu') + parser.add_argument("--keep_prob", type=float, default=1.0) + + parser.add_argument("--output_path", type=str, default="./output/") + parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") + parser.add_argument("--eval_file_name", type=str, default="eval.log") + parser.add_argument("--loss_file_name", type=str, default="loss.log") + return parser + + +class Config_WideDeep(): + """ + Config_WideDeep + """ + def __init__(self): + self.data_path = "./test_raw_data/" + self.epochs = 15 + self.batch_size = 10000 + self.eval_batch_size = 10000 + self.field_size = 39 + self.vocab_size = 184965 + self.emb_dim = 80 + self.deep_layer_dim = [1024, 512, 256, 128] + self.deep_layer_act = 'relu' + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = False + self.keep_prob = 1.0 + self.l2_coef = 8e-5 + + self.output_path = "./output" + self.eval_file_name = "eval.log" + self.loss_file_name = "loss.log" + self.ckpt_path = "./checkpoints/" + + def argparse_init(self): + """ + argparse_init + """ + parser = argparse_init() + args, _ = parser.parse_known_args() + self.epochs = args.epochs + self.batch_size = args.batch_size + self.eval_batch_size = args.eval_batch_size + self.field_size = args.field_size + self.vocab_size = args.vocab_size + self.emb_dim = args.emb_dim + self.deep_layer_dim = args.deep_layer_dim + self.deep_layer_act = args.deep_layer_act + self.keep_prob = args.keep_prob + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = False + self.l2_coef = 8e-5 + + self.output_path = args.output_path + self.eval_file_name = args.eval_file_name + self.loss_file_name = args.loss_file_name + self.ckpt_path = args.ckpt_path diff --git a/model_zoo/wide_and_deep/tools/train_and_test.py b/model_zoo/wide_and_deep/tools/train_and_test.py new file mode 100644 index 00000000000..9f08377c75e --- /dev/null +++ b/model_zoo/wide_and_deep/tools/train_and_test.py @@ -0,0 +1,97 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" test_training """ +import os + +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig + +from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from wide_deep.utils.callbacks import LossCallBack, EvalCallBack +from wide_deep.data.datasets import create_dataset +from wide_deep.utils.metrics import AUCMetric +from tools.config import Config_WideDeep + +context.set_context(mode=context.GRAPH_MODE, device_target="Davinci") + + +def get_WideDeep_net(config): + WideDeep_net = WideDeepModel(config) + + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def test_train_eval(config): + """ + test_train_eval + """ + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) + ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + callback = LossCallBack() + ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) + + out = model.eval(ds_eval) + print("=====" * 5 + "model.eval() initialized: {}".format(out)) + model.train(epochs, ds_train, callbacks=[eval_callback, callback, ckpoint_cb]) + + +if __name__ == "__main__": + wide_deep_config = Config_WideDeep() + wide_deep_config.argparse_init() + + test_train_eval(wide_deep_config)