add embedding cache st for ascend sparse mode
This commit is contained in:
parent
b85e398389
commit
30e63de701
|
@ -45,7 +45,7 @@ do
|
|||
rm -rf ${self_path}/server_$i/
|
||||
mkdir ${self_path}/server_$i/
|
||||
cd ${self_path}/server_$i/ || exit
|
||||
python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET >server_$i.log 2>&1 &
|
||||
python ${self_path}/test_embedding_cache_distribute.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server_$i.log 2>&1 &
|
||||
server_pids[${i}]=`echo $!`
|
||||
done
|
||||
|
||||
|
|
|
@ -21,19 +21,24 @@ export MS_WORKER_NUM=1
|
|||
export MS_SERVER_NUM=1
|
||||
export MS_SCHED_HOST=$2
|
||||
export MS_SCHED_PORT=$3
|
||||
export SPARSE=$4
|
||||
|
||||
if [[ ! -n "$4" ]]; then
|
||||
export SPARSE=0
|
||||
fi
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
rm -rf ${self_path}/sched/
|
||||
mkdir ${self_path}/sched/
|
||||
cd ${self_path}/sched/ || exit
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >sched.log 2>&1 &
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >sched.log 2>&1 &
|
||||
sched_pid=`echo $!`
|
||||
|
||||
export MS_ROLE=MS_PSERVER
|
||||
rm -rf ${self_path}/server/
|
||||
mkdir ${self_path}/server/
|
||||
cd ${self_path}/server/ || exit
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET >server.log 2>&1 &
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE >server.log 2>&1 &
|
||||
server_pid=`echo $!`
|
||||
|
||||
export MS_ROLE=MS_WORKER
|
||||
|
@ -41,7 +46,7 @@ rm -rf ${self_path}/worker/
|
|||
mkdir ${self_path}/worker/
|
||||
cd ${self_path}/worker/ || exit
|
||||
export RANK_ID=0
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET &>worker.log 2>&1 &
|
||||
python ${self_path}/test_embedding_cache_standalone.py --device_target=$DEVICE_TARGET --sparse=$SPARSE &>worker.log 2>&1 &
|
||||
worker_pid=`echo $!`
|
||||
|
||||
wait ${worker_pid}
|
||||
|
|
|
@ -23,13 +23,16 @@ from src.model import ModelExecutor
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="test_embedding_cache_standalone")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
sparse = bool(args.sparse)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
context.set_ps_context(enable_ps=True)
|
||||
init()
|
||||
|
||||
dataset = create_dataset(resize_height=32, resize_width=32, scale=30.0)
|
||||
executor = ModelExecutor(dataset=dataset, sparse=False, vocab_cache_size=5000, in_channels=30720,
|
||||
executor = ModelExecutor(dataset=dataset, sparse=sparse, vocab_cache_size=5000, in_channels=30720,
|
||||
out_channels=12, input_shape=[32, 3, 32, 32])
|
||||
executor.run_embedding_cache()
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_embedding_cache_standalone_sparse_ascend():
|
||||
"""
|
||||
Feature: Test embedding cache feature on ascend with 1 worker, 1 server, and enable saparse mode.
|
||||
Description: Worker trains network containing embedding layers and enable embedding cache for saparse mode.
|
||||
Expectation: All process execute and exit normal.
|
||||
"""
|
||||
|
||||
self_path = os.path.split(os.path.realpath(__file__))[0]
|
||||
return_code = os.system(f"bash {self_path}/run_test_embedding_cache_standalone.sh Ascend 127.0.0.1 8022 1")
|
||||
if return_code != 0:
|
||||
os.system(f"echo '\n**************** Worker Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' {self_path}/worker*/worker*.log")
|
||||
os.system(f"echo '\n**************** Server Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' {self_path}/server/server.log")
|
||||
os.system(f"echo '\n**************** Scheduler Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' {self_path}/sched/sched.log")
|
||||
assert return_code == 0
|
Loading…
Reference in New Issue