Add ps ci test cases.

This commit is contained in:
ZPaC 2020-07-30 15:38:52 +08:00
parent 22927dc4f7
commit b10d4d6e0d
12 changed files with 480 additions and 24 deletions

View File

@ -25,9 +25,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
size_t axis = kShape4dDims - input_shape.size();
CPUKernelUtils::ExpandDimsTo4(&input_shape);
CPUKernelUtils::ExpandDimsTo4(&output_shape);
size_t axis = kShape2dDims - input_shape.size();
for (auto dim : input_shape) {
input_dims_ *= dim;
}
@ -40,6 +38,8 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
values.insert(values.end(), input_shape.begin(), input_shape.end());
values.insert(values.end(), indices_shape.begin(), indices_shape.end());
values.insert(values.end(), output_shape.begin(), output_shape.end());
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {

View File

@ -25,11 +25,15 @@ namespace mindspore {
namespace kernel {
namespace ps {
using mindspore::parallel::ps::Util;
constexpr int kAxis = 2;
constexpr int kAxis = 0;
void EmbeddingLookUpPSKernel::InitKernel(
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
input_shape_ = *(shape_vec[0]);
first_dim_size_ = input_shape_[0];
for (size_t i = 1; i < input_shape_.size(); ++i) {
outer_dim_size_ *= input_shape_[i];
}
auto indices_shape = *(shape_vec[1]);
indices_lens_ = 1;
for (auto shape : indices_shape) {
@ -49,7 +53,6 @@ void EmbeddingLookUpPSKernel::InitKernel(
size_t output_size =
std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies<size_t>());
output_size_list_.emplace_back(output_size);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
}
void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {

View File

@ -77,7 +77,7 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
size_t worker_num) {
AddressPtr weight_addr = std::make_shared<kernel::Address>();
weight_addr->addr = weight->data();
weight_addr->size = weight->size();
weight_addr->size = weight->size() * sizeof(float);
AddressPtr m = std::make_shared<kernel::Address>();
m->addr = new float[weight->size()];
m->size = weight->size() * sizeof(float);
@ -156,7 +156,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
size_t worker_num) {
AddressPtr weight_addr = std::make_shared<kernel::Address>();
weight_addr->addr = weight->data();
weight_addr->size = weight->size();
weight_addr->size = weight->size() * sizeof(float);
AddressPtr accum = std::make_shared<kernel::Address>();
accum->addr = new float[weight->size()];
accum->size = weight->size() * sizeof(float);
@ -166,7 +166,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
}
AddressPtr linear = std::make_shared<kernel::Address>();
linear->addr = new float[weight->size()];
auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
}
@ -176,9 +176,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>());
AddressPtr grad = std::make_shared<kernel::Address>();
grad->addr = new float[total_grad_size * worker_num];
auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
if (ret1 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")";
ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
grad->size = lens[0] * sizeof(float);
@ -187,10 +187,10 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies<size_t>());
AddressPtr indices = std::make_shared<kernel::Address>();
indices->addr = new float[total_indice_size * worker_num];
auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0],
lens[1] * sizeof(float));
if (ret2 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0],
lens[1] * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
indices->size = lens[1] * sizeof(int);

View File

@ -0,0 +1,75 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
#bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
# LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM
# SCHED_HOST SCHED_PORT ROLE
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export LOCAL_WORKER_NUM=$5
export LOCAL_SERVER_NUM=$6
export MS_SERVER_NUM=$7
export MS_SCHED_HOST=$8
export MS_SCHED_PORT=$9
export MS_ROLE=${10}
echo "=====Role is $MS_ROLE======"
if [ "$MS_ROLE" == "MS_SCHED" ];then
for((i=0;i<1;i++));
do
rm -rf ${execute_path}/sched_$i/
mkdir ${execute_path}/sched_$i/
cd ${execute_path}/sched_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 &
done
fi
if [ "$MS_ROLE" == "MS_PSERVER" ];then
for((i=0;i<$LOCAL_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 &
done
fi
if [ "$MS_ROLE" == "MS_WORKER" ];then
for((i=0;i<$LOCAL_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 &
done
fi

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
export MS_WORKER_NUM=$2
export MS_SERVER_NUM=$3
export MS_SCHED_HOST=$4
export MS_SCHED_PORT=$5
export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
rm -rf ${execute_path}/sched_$i/
mkdir ${execute_path}/sched_$i/
cd ${execute_path}/sched_$i/ || exit
python ${self_path}/../test_cmp_sparse_embedding.py &
done
export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
python ${self_path}/../test_cmp_sparse_embedding.py &
done
export MS_ROLE=MS_WORKER
for((i=0;i<$MS_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
python ${self_path}/../test_cmp_sparse_embedding.py &
done
wait $!
exit $?

View File

@ -0,0 +1,106 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore import Parameter
parser = argparse.ArgumentParser(description="test_sparse_embedding")
parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target
context.set_context(
mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
)
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.cast = P.Cast()
self.flatten = nn.Flatten()
self.embedding_table = Parameter(
initializer("normal", (16, 4), mstype.float32), name="embedding_table"
)
self.embedding = nn.EmbeddingLookup()
self.relu = nn.ReLU()
self.fc = fc_with_initialize(12, num_class)
def construct(self, x):
x = self.cast(x, mstype.int32)
x = self.embedding(self.embedding_table, x)
x = self.flatten(x)
x = self.fc(x)
return x
def do_sparse_embedding(ps=False):
epoch = 10
net = LeNet5(10)
if ps:
net.embedding_table.set_param_ps()
optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
criterion = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction="mean"
)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()
losses = []
for _ in range(epoch):
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
label = Tensor(np.random.randint(0, 9, (32), np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)
return losses
envs = os.environ
if __name__ == "__main__":
np.random.seed(0)
ps_loss = do_sparse_embedding(True)
if envs.get("MS_ROLE") == "MS_WORKER":
envs["MS_ROLE"] = ""
np.random.seed(0)
no_ps_loss = do_sparse_embedding()
envs["MS_ROLE"] = "MS_WORKER"
assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)

View File

@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cmp_sparse_embedding():
return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081")
assert return_code == 0

View File

@ -0,0 +1,56 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
DATASET_PATH=$2
export MS_WORKER_NUM=$3
export MS_SERVER_NUM=$4
export MS_SCHED_HOST=$5
export MS_SCHED_PORT=$6
export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
rm -rf ${execute_path}/sched_$i/
mkdir ${execute_path}/sched_$i/
cd ${execute_path}/sched_$i/ || exit
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
done
export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
done
export MS_ROLE=MS_WORKER
for((i=0;i<$MS_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
done
wait $!
exit $?

View File

@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_full_ps_ascend_lenet():
return_code = os.system(
"bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082"
)
assert return_code == 0

View File

@ -17,14 +17,16 @@ import os
# @pytest.mark.level0
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
# @pytest.mark.env_onecard
def test_full_ps_ascend_lenet():
return_code = os.system("bash run_full_ps_lenet.sh Ascend 1 1 127.0.0.1 8088")
# @pytest.mark.env_single
def test_multi_worker_full_ps_ascend_lenet():
return_code = os.system("bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088")
assert return_code == 0
# @pytest.mark.level0
# @pytest.mark.platform_x86_gpu_training
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
# @pytest.mark.env_onecard
def test_full_ps_gpu_lenet():
return_code = os.system("bash run_full_ps_lenet.sh GPU 1 1 127.0.0.1 8088")
def test_full_ps_ascend_lenet():
return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088")
assert return_code == 0

View File

@ -32,7 +32,7 @@ do
cd ${execute_path}/sched_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done
export MS_ROLE=MS_PSERVER
@ -43,7 +43,7 @@ do
cd ${execute_path}/server_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done
export MS_ROLE=MS_WORKER
@ -54,7 +54,7 @@ do
cd ${execute_path}/worker_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
done
wait $!

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
parser = argparse.ArgumentParser(description="test_ps_lenet")
parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_init=weight,
has_bias=False,
pad_mode="valid",
)
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=3):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
if __name__ == "__main__":
epoch = 5
np.random.seed(0)
network = LeNet5(10)
network.set_param_ps()
criterion = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction="mean"
)
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_with_criterion = WithLossCell(network, criterion)
train_network = TrainOneStepCell(net_with_criterion, net_opt)
train_network.set_train()
losses = []
for _ in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)