add embedding cache st for ascend sparse mode

This commit is contained in:
lizhenyu 2022-11-28 10:28:42 +08:00
parent b85e398389
commit 30e63de701
4 changed files with 53 additions and 5 deletions

View File

@ -45,7 +45,7 @@ do
rm -rf ${self_path}/server_$i/
mkdir ${self_path}/server_$i/
cd ${self_path}/server_$i/ || exit
python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET >server_$i.log 2>&1 &
python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server_$i.log 2>&1 &
server_pids[${i}]=`echo $!`
done

View File

@ -21,19 +21,24 @@ export MS_WORKER_NUM=1
export MS_SERVER_NUM=1
export MS_SCHED_HOST=$2
export MS_SCHED_PORT=$3
export SPARSE=$4
if [[ ! -n "$4" ]]; then
export SPARSE=0
fi
export MS_ROLE=MS_SCHED
rm -rf ${self_path}/sched/
mkdir ${self_path}/sched/
cd ${self_path}/sched/ || exit
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >sched.log 2>&1 &
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >sched.log 2>&1 &
sched_pid=`echo $!`
export MS_ROLE=MS_PSERVER
rm -rf ${self_path}/server/
mkdir ${self_path}/server/
cd ${self_path}/server/ || exit
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >server.log 2>&1 &
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server.log 2>&1 &
server_pid=`echo $!`
export MS_ROLE=MS_WORKER
@ -41,7 +46,7 @@ rm -rf ${self_path}/worker/
mkdir ${self_path}/worker/
cd ${self_path}/worker/ || exit
export RANK_ID=0
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET &>worker.log 2>&1 &
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE &>worker.log 2>&1 &
worker_pid=`echo $!`
wait ${worker_pid}

View File

@ -23,13 +23,16 @@ from src.model import ModelExecutor
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="test_embedding_cache_standalone")
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not")
args, _ = parser.parse_known_args()
device_target = args.device_target
sparse = bool(args.sparse)
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)
init()
dataset = create_dataset(resize_height=32, resize_width=32, scale=30.0)
executor = ModelExecutor(dataset=dataset, sparse=False, vocab_cache_size=5000, in_channels=30720,
executor = ModelExecutor(dataset=dataset, sparse=sparse, vocab_cache_size=5000, in_channels=30720,
out_channels=12, input_shape=[32, 3, 32, 32])
executor.run_embedding_cache()

View File

@ -0,0 +1,40 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_embedding_cache_standalone_sparse_ascend():
    """
    Feature: Test embedding cache feature on Ascend with 1 worker, 1 server, and sparse mode enabled.
    Description: Worker trains a network containing embedding layers with embedding cache enabled in sparse mode.
    Expectation: All processes execute and exit normally.
    """
    self_path = os.path.split(os.path.realpath(__file__))[0]
    # Positional arguments passed to the launcher: "Ascend", "127.0.0.1", "8022", "1".
    # NOTE(review): presumably device target, scheduler host, scheduler port and the
    # sparse switch (1 = sparse on) -- confirm against run_test_embedding_cache_standalone.sh.
    return_code = os.system(f"bash {self_path}/run_test_embedding_cache_standalone.sh Ascend 127.0.0.1 8022 1")
    if return_code != 0:
        # On failure, surface the error lines of every role's log so the cause is
        # visible in the test output before the assertion below fires.
        role_logs = (
            ("Worker", f"{self_path}/worker*/worker*.log"),
            ("Server", f"{self_path}/server/server.log"),
            ("Scheduler", f"{self_path}/sched/sched.log"),
        )
        for role, log_path in role_logs:
            os.system(f"echo '\n**************** {role} Log ****************'")
            os.system(f"grep -E 'ERROR|Error|error' {log_path}")
    assert return_code == 0