fix some network ncf scripts error

This commit is contained in:
anzhengqi 2021-12-23 19:11:53 +08:00
parent 7f85fe3026
commit 32e2d06350
3 changed files with 4 additions and 6 deletions

View File

@ -81,7 +81,7 @@ def test_eval():
if __name__ == '__main__':
devid = int(os.getenv('DEVICE_ID'))
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Davinci",
save_graphs=True,

View File

@ -36,7 +36,6 @@ do
--dataset 'ml-1m' \
--train_epochs 50 \
--output_path './output/' \
--eval_file_name 'eval.log' \
--loss_file_name 'loss.log' \
--checkpoint_path './checkpoint/' \
--device_target="Ascend" \

View File

@ -14,9 +14,8 @@
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_transfer_ckpt_to_air.sh DATASET_PATH CKPT_FILE"
echo "for example: sh scripts/run_transfer_ckpt_to_air.sh /dataset_path /ncf.ckpt"
echo "sh scripts/run_transfer_ckpt_to_air.sh CKPT_FILE"
echo "for example: sh scripts/run_transfer_ckpt_to_air.sh /ncf.ckpt"
data_path=$1
ckpt_file=$2
python ./src/export.py --data_path $data_path --dataset 'ml-1m' --eval_batch_size 160000 --output_path './output/' --eval_file_name 'eval.log' --checkpoint_file_path $ckpt_file
python ./export.py --dataset 'ml-1m' --ckpt_file $ckpt_file