From 30e63de701d975e21ff20c0c8ceeb2bafe1f96d9 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Mon, 28 Nov 2022 10:28:42 +0800 Subject: [PATCH] add embedding cache st for ascend sparse mode --- .../run_test_embedding_cache_distribute.sh | 2 +- .../run_test_embedding_cache_standalone.sh | 11 +++-- .../test_embedding_cache_standalone.py | 5 ++- ...mbedding_cache_standalone_sparse_ascend.py | 40 +++++++++++++++++++ 4 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/st/ps/embedding_cache/test_entry_embedding_cache_standalone_sparse_ascend.py diff --git a/tests/st/ps/embedding_cache/run_test_embedding_cache_distribute.sh b/tests/st/ps/embedding_cache/run_test_embedding_cache_distribute.sh index 5ee6cc43f66..1445f27eed6 100644 --- a/tests/st/ps/embedding_cache/run_test_embedding_cache_distribute.sh +++ b/tests/st/ps/embedding_cache/run_test_embedding_cache_distribute.sh @@ -45,7 +45,7 @@ do rm -rf ${self_path}/server_$i/ mkdir ${self_path}/server_$i/ cd ${self_path}/server_$i/ || exit - python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET >server_$i.log 2>&1 & + python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server_$i.log 2>&1 & server_pids[${i}]=`echo $!` done diff --git a/tests/st/ps/embedding_cache/run_test_embedding_cache_standalone.sh b/tests/st/ps/embedding_cache/run_test_embedding_cache_standalone.sh index 0819d0d85f1..cf8dde344ad 100644 --- a/tests/st/ps/embedding_cache/run_test_embedding_cache_standalone.sh +++ b/tests/st/ps/embedding_cache/run_test_embedding_cache_standalone.sh @@ -21,19 +21,24 @@ export MS_WORKER_NUM=1 export MS_SERVER_NUM=1 export MS_SCHED_HOST=$2 export MS_SCHED_PORT=$3 +export SPARSE=$4 + +if [[ ! -n "$4" ]]; then + export SPARSE=0 +fi export MS_ROLE=MS_SCHED rm -rf ${self_path}/sched/ mkdir ${self_path}/sched/ cd ${self_path}/sched/ || exit -python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >sched.log 2>&1 & +python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >sched.log 2>&1 & sched_pid=`echo $!` export MS_ROLE=MS_PSERVER rm -rf ${self_path}/server/ mkdir ${self_path}/server/ cd ${self_path}/server/ || exit -python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >server.log 2>&1 & +python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server.log 2>&1 & server_pid=`echo $!` export MS_ROLE=MS_WORKER @@ -41,7 +46,7 @@ rm -rf ${self_path}/worker/ mkdir ${self_path}/worker/ cd ${self_path}/worker/ || exit export RANK_ID=0 -python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET &>worker.log 2>&1 & +python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE &>worker.log 2>&1 & worker_pid=`echo $!` wait ${worker_pid} diff --git a/tests/st/ps/embedding_cache/test_embedding_cache_standalone.py b/tests/st/ps/embedding_cache/test_embedding_cache_standalone.py index 4c72c2bc607..4b0ffa49f04 100644 --- a/tests/st/ps/embedding_cache/test_embedding_cache_standalone.py +++ b/tests/st/ps/embedding_cache/test_embedding_cache_standalone.py @@ -23,13 +23,16 @@ from src.model import ModelExecutor if __name__ == "__main__": parser = argparse.ArgumentParser(description="test_embedding_cache_standalone") parser.add_argument("--device_target", type=str, default="Ascend") + parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") args, _ = parser.parse_known_args() device_target = args.device_target + sparse = bool(args.sparse) + context.set_context(mode=context.GRAPH_MODE, device_target=device_target) context.set_ps_context(enable_ps=True) init() dataset = create_dataset(resize_height=32, resize_width=32, scale=30.0) - executor = ModelExecutor(dataset=dataset, sparse=False, vocab_cache_size=5000, in_channels=30720, + executor = ModelExecutor(dataset=dataset, sparse=sparse, vocab_cache_size=5000, in_channels=30720, out_channels=12, input_shape=[32, 3, 32, 32]) executor.run_embedding_cache() diff --git a/tests/st/ps/embedding_cache/test_entry_embedding_cache_standalone_sparse_ascend.py b/tests/st/ps/embedding_cache/test_entry_embedding_cache_standalone_sparse_ascend.py new file mode 100644 index 00000000000..617bc6898df --- /dev/null +++ b/tests/st/ps/embedding_cache/test_entry_embedding_cache_standalone_sparse_ascend.py @@ -0,0 +1,40 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_single +def test_embedding_cache_standalone_sparse_ascend(): + """ + Feature: Test embedding cache feature on ascend with 1 worker, 1 server, and enable saparse mode. + Description: Worker trains network containing embedding layers and enable embedding cache for saparse mode. + Expectation: All process execute and exit normal. + """ + + self_path = os.path.split(os.path.realpath(__file__))[0] + return_code = os.system(f"bash {self_path}/run_test_embedding_cache_standalone.sh Ascend 127.0.0.1 8022 1") + if return_code != 0: + os.system(f"echo '\n**************** Worker Log ****************'") + os.system(f"grep -E 'ERROR|Error|error' {self_path}/worker*/worker*.log") + os.system(f"echo '\n**************** Server Log ****************'") + os.system(f"grep -E 'ERROR|Error|error' {self_path}/server/server.log") + os.system(f"echo '\n**************** Scheduler Log ****************'") + os.system(f"grep -E 'ERROR|Error|error' {self_path}/sched/sched.log") + assert return_code == 0