forked from OSSInnovation/mindspore
Add ps ci test cases.
This commit is contained in:
parent
22927dc4f7
commit
b10d4d6e0d
|
@ -25,9 +25,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
size_t axis = kShape4dDims - input_shape.size();
|
||||
CPUKernelUtils::ExpandDimsTo4(&input_shape);
|
||||
CPUKernelUtils::ExpandDimsTo4(&output_shape);
|
||||
size_t axis = kShape2dDims - input_shape.size();
|
||||
for (auto dim : input_shape) {
|
||||
input_dims_ *= dim;
|
||||
}
|
||||
|
@ -40,6 +38,8 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
values.insert(values.end(), input_shape.begin(), input_shape.end());
|
||||
values.insert(values.end(), indices_shape.begin(), indices_shape.end());
|
||||
values.insert(values.end(), output_shape.begin(), output_shape.end());
|
||||
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
|
||||
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
|
||||
std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
|
||||
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
|
||||
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
|
||||
|
|
|
@ -25,11 +25,15 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace ps {
|
||||
using mindspore::parallel::ps::Util;
|
||||
constexpr int kAxis = 2;
|
||||
constexpr int kAxis = 0;
|
||||
void EmbeddingLookUpPSKernel::InitKernel(
|
||||
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
||||
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
|
||||
input_shape_ = *(shape_vec[0]);
|
||||
first_dim_size_ = input_shape_[0];
|
||||
for (size_t i = 1; i < input_shape_.size(); ++i) {
|
||||
outer_dim_size_ *= input_shape_[i];
|
||||
}
|
||||
auto indices_shape = *(shape_vec[1]);
|
||||
indices_lens_ = 1;
|
||||
for (auto shape : indices_shape) {
|
||||
|
@ -49,7 +53,6 @@ void EmbeddingLookUpPSKernel::InitKernel(
|
|||
size_t output_size =
|
||||
std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies<size_t>());
|
||||
output_size_list_.emplace_back(output_size);
|
||||
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
||||
}
|
||||
|
||||
void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
||||
|
|
|
@ -77,7 +77,7 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
size_t worker_num) {
|
||||
AddressPtr weight_addr = std::make_shared<kernel::Address>();
|
||||
weight_addr->addr = weight->data();
|
||||
weight_addr->size = weight->size();
|
||||
weight_addr->size = weight->size() * sizeof(float);
|
||||
AddressPtr m = std::make_shared<kernel::Address>();
|
||||
m->addr = new float[weight->size()];
|
||||
m->size = weight->size() * sizeof(float);
|
||||
|
@ -156,7 +156,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
size_t worker_num) {
|
||||
AddressPtr weight_addr = std::make_shared<kernel::Address>();
|
||||
weight_addr->addr = weight->data();
|
||||
weight_addr->size = weight->size();
|
||||
weight_addr->size = weight->size() * sizeof(float);
|
||||
AddressPtr accum = std::make_shared<kernel::Address>();
|
||||
accum->addr = new float[weight->size()];
|
||||
accum->size = weight->size() * sizeof(float);
|
||||
|
@ -166,7 +166,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
}
|
||||
AddressPtr linear = std::make_shared<kernel::Address>();
|
||||
linear->addr = new float[weight->size()];
|
||||
auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
|
||||
int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
@ -176,9 +176,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>());
|
||||
AddressPtr grad = std::make_shared<kernel::Address>();
|
||||
grad->addr = new float[total_grad_size * worker_num];
|
||||
auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
|
||||
if (ret1 != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")";
|
||||
ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
grad->size = lens[0] * sizeof(float);
|
||||
|
||||
|
@ -187,10 +187,10 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies<size_t>());
|
||||
AddressPtr indices = std::make_shared<kernel::Address>();
|
||||
indices->addr = new float[total_indice_size * worker_num];
|
||||
auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0],
|
||||
lens[1] * sizeof(float));
|
||||
if (ret2 != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
|
||||
ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0],
|
||||
lens[1] * sizeof(float));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
indices->size = lens[1] * sizeof(int);
|
||||
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
|
||||
#bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
|
||||
# LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM
|
||||
# SCHED_HOST SCHED_PORT ROLE
|
||||
export RANK_SIZE=$1
|
||||
export EPOCH_SIZE=$2
|
||||
export DATASET=$3
|
||||
export RANK_TABLE_FILE=$4
|
||||
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=$RANK_SIZE
|
||||
export LOCAL_WORKER_NUM=$5
|
||||
export LOCAL_SERVER_NUM=$6
|
||||
export MS_SERVER_NUM=$7
|
||||
export MS_SCHED_HOST=$8
|
||||
export MS_SCHED_PORT=$9
|
||||
export MS_ROLE=${10}
|
||||
echo "=====Role is $MS_ROLE======"
|
||||
|
||||
|
||||
if [ "$MS_ROLE" == "MS_SCHED" ];then
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/sched_$i/
|
||||
mkdir ${execute_path}/sched_$i/
|
||||
cd ${execute_path}/sched_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 &
|
||||
done
|
||||
fi
|
||||
|
||||
if [ "$MS_ROLE" == "MS_PSERVER" ];then
|
||||
for((i=0;i<$LOCAL_SERVER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/server_$i/
|
||||
mkdir ${execute_path}/server_$i/
|
||||
cd ${execute_path}/server_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 &
|
||||
done
|
||||
fi
|
||||
|
||||
if [ "$MS_ROLE" == "MS_WORKER" ];then
|
||||
for((i=0;i<$LOCAL_WORKER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/worker_$i/
|
||||
mkdir ${execute_path}/worker_$i/
|
||||
cd ${execute_path}/worker_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 &
|
||||
done
|
||||
fi
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
export MS_WORKER_NUM=$2
|
||||
export MS_SERVER_NUM=$3
|
||||
export MS_SCHED_HOST=$4
|
||||
export MS_SCHED_PORT=$5
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/sched_$i/
|
||||
mkdir ${execute_path}/sched_$i/
|
||||
cd ${execute_path}/sched_$i/ || exit
|
||||
python ${self_path}/../test_cmp_sparse_embedding.py &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_PSERVER
|
||||
for((i=0;i<$MS_SERVER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/server_$i/
|
||||
mkdir ${execute_path}/server_$i/
|
||||
cd ${execute_path}/server_$i/ || exit
|
||||
python ${self_path}/../test_cmp_sparse_embedding.py &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_WORKER
|
||||
for((i=0;i<$MS_WORKER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/worker_$i/
|
||||
mkdir ${execute_path}/worker_$i/
|
||||
cd ${execute_path}/worker_$i/ || exit
|
||||
python ${self_path}/../test_cmp_sparse_embedding.py &
|
||||
done
|
||||
|
||||
wait $!
|
||||
exit $?
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore import Parameter
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_sparse_embedding")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True
|
||||
)
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""weight initial for fc layer"""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
def __init__(self, num_class=10):
|
||||
super(LeNet5, self).__init__()
|
||||
self.cast = P.Cast()
|
||||
self.flatten = nn.Flatten()
|
||||
self.embedding_table = Parameter(
|
||||
initializer("normal", (16, 4), mstype.float32), name="embedding_table"
|
||||
)
|
||||
self.embedding = nn.EmbeddingLookup()
|
||||
self.relu = nn.ReLU()
|
||||
self.fc = fc_with_initialize(12, num_class)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.cast(x, mstype.int32)
|
||||
x = self.embedding(self.embedding_table, x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def do_sparse_embedding(ps=False):
|
||||
epoch = 10
|
||||
net = LeNet5(10)
|
||||
if ps:
|
||||
net.embedding_table.set_param_ps()
|
||||
|
||||
optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
|
||||
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(
|
||||
is_grad=False, sparse=True, reduction="mean"
|
||||
)
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
train_network = TrainOneStepCell(net_with_criterion, optimizer)
|
||||
train_network.set_train()
|
||||
losses = []
|
||||
for _ in range(epoch):
|
||||
data = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
|
||||
label = Tensor(np.random.randint(0, 9, (32), np.int32))
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
||||
return losses
|
||||
|
||||
|
||||
envs = os.environ
|
||||
if __name__ == "__main__":
|
||||
np.random.seed(0)
|
||||
ps_loss = do_sparse_embedding(True)
|
||||
|
||||
if envs.get("MS_ROLE") == "MS_WORKER":
|
||||
envs["MS_ROLE"] = ""
|
||||
np.random.seed(0)
|
||||
no_ps_loss = do_sparse_embedding()
|
||||
envs["MS_ROLE"] = "MS_WORKER"
|
||||
|
||||
assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cmp_sparse_embedding():
|
||||
return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081")
|
||||
assert return_code == 0
|
|
@ -0,0 +1,56 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
DATASET_PATH=$2
|
||||
export MS_WORKER_NUM=$3
|
||||
export MS_SERVER_NUM=$4
|
||||
export MS_SCHED_HOST=$5
|
||||
export MS_SCHED_PORT=$6
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/sched_$i/
|
||||
mkdir ${execute_path}/sched_$i/
|
||||
cd ${execute_path}/sched_$i/ || exit
|
||||
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_PSERVER
|
||||
for((i=0;i<$MS_SERVER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/server_$i/
|
||||
mkdir ${execute_path}/server_$i/
|
||||
cd ${execute_path}/server_$i/ || exit
|
||||
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_WORKER
|
||||
for((i=0;i<$MS_WORKER_NUM;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/worker_$i/
|
||||
mkdir ${execute_path}/worker_$i/
|
||||
cd ${execute_path}/worker_$i/ || exit
|
||||
python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH &
|
||||
done
|
||||
|
||||
wait $!
|
||||
exit $?
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_full_ps_ascend_lenet():
|
||||
return_code = os.system(
|
||||
"bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082"
|
||||
)
|
||||
assert return_code == 0
|
|
@ -17,14 +17,16 @@ import os
|
|||
# @pytest.mark.level0
|
||||
# @pytest.mark.platform_arm_ascend_training
|
||||
# @pytest.mark.platform_x86_ascend_training
|
||||
# @pytest.mark.env_onecard
|
||||
def test_full_ps_ascend_lenet():
|
||||
return_code = os.system("bash run_full_ps_lenet.sh Ascend 1 1 127.0.0.1 8088")
|
||||
# @pytest.mark.env_single
|
||||
def test_multi_worker_full_ps_ascend_lenet():
|
||||
return_code = os.system("bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088")
|
||||
assert return_code == 0
|
||||
|
||||
|
||||
# @pytest.mark.level0
|
||||
# @pytest.mark.platform_x86_gpu_training
|
||||
# @pytest.mark.platform_arm_ascend_training
|
||||
# @pytest.mark.platform_x86_ascend_training
|
||||
# @pytest.mark.env_onecard
|
||||
def test_full_ps_gpu_lenet():
|
||||
return_code = os.system("bash run_full_ps_lenet.sh GPU 1 1 127.0.0.1 8088")
|
||||
def test_full_ps_ascend_lenet():
|
||||
return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088")
|
||||
assert return_code == 0
|
|
@ -32,7 +32,7 @@ do
|
|||
cd ${execute_path}/sched_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_PSERVER
|
||||
|
@ -43,7 +43,7 @@ do
|
|||
cd ${execute_path}/server_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
done
|
||||
|
||||
export MS_ROLE=MS_WORKER
|
||||
|
@ -54,7 +54,7 @@ do
|
|||
cd ${execute_path}/worker_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET &
|
||||
done
|
||||
|
||||
wait $!
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_ps_lenet")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
"""weight initial for conv layer"""
|
||||
weight = weight_variable()
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
weight_init=weight,
|
||||
has_bias=False,
|
||||
pad_mode="valid",
|
||||
)
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""weight initial for fc layer"""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
def __init__(self, num_class=10, channel=3):
|
||||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.conv1 = conv(channel, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
epoch = 5
|
||||
np.random.seed(0)
|
||||
network = LeNet5(10)
|
||||
network.set_param_ps()
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(
|
||||
is_grad=False, sparse=True, reduction="mean"
|
||||
)
|
||||
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
||||
|
||||
net_with_criterion = WithLossCell(network, criterion)
|
||||
train_network = TrainOneStepCell(net_with_criterion, net_opt)
|
||||
train_network.set_train()
|
||||
losses = []
|
||||
for _ in range(epoch):
|
||||
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
|
||||
label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
Loading…
Reference in New Issue