diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index a97534f980..ea22397a3e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -25,9 +25,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - size_t axis = kShape4dDims - input_shape.size(); - CPUKernelUtils::ExpandDimsTo4(&input_shape); - CPUKernelUtils::ExpandDimsTo4(&output_shape); + size_t axis = kShape2dDims - input_shape.size(); for (auto dim : input_shape) { input_dims_ *= dim; } @@ -40,6 +38,8 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { values.insert(values.end(), input_shape.begin(), input_shape.end()); values.insert(values.end(), indices_shape.begin(), indices_shape.end()); values.insert(values.end(), output_shape.begin(), output_shape.end()); + MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape + << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; std::vector lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())}; const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc index 430f75f79e..4a36628dc7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -25,11 +25,15 @@ namespace mindspore { namespace kernel { namespace ps { using mindspore::parallel::ps::Util; -constexpr int kAxis = 2; +constexpr int kAxis = 0; void EmbeddingLookUpPSKernel::InitKernel( const std::shared_ptr>>> &shapes) { const std::vector>> &shape_vec = *shapes; input_shape_ = *(shape_vec[0]); + first_dim_size_ = input_shape_[0]; + for (size_t i = 1; i < input_shape_.size(); ++i) { + outer_dim_size_ *= input_shape_[i]; + } auto indices_shape = *(shape_vec[1]); indices_lens_ = 1; for (auto shape : indices_shape) { @@ -49,7 +53,6 @@ void EmbeddingLookUpPSKernel::InitKernel( size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies()); output_size_list_.emplace_back(output_size); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); } void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc index 60f8d10f1c..e1d5ffb32a 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -77,7 +77,7 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t worker_num) { AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); + weight_addr->size = weight->size() * sizeof(float); AddressPtr m = std::make_shared(); m->addr = new float[weight->size()]; m->size = weight->size() * sizeof(float); @@ -156,7 +156,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t worker_num) { AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); + weight_addr->size = weight->size() * sizeof(float); AddressPtr accum = std::make_shared(); accum->addr = new float[weight->size()]; accum->size = weight->size() * sizeof(float); @@ -166,7 +166,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, } AddressPtr linear = std::make_shared(); linear->addr = new float[weight->size()]; - auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); + int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")"; } @@ -176,9 +176,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies()); AddressPtr grad = std::make_shared(); grad->addr = new float[total_grad_size * worker_num]; - auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); - if (ret1 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")"; + ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } grad->size = lens[0] * sizeof(float); @@ -187,10 +187,10 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies()); AddressPtr indices = std::make_shared(); indices->addr = new float[total_indice_size * worker_num]; - auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], - lens[1] * sizeof(float)); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], + lens[1] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } indices->size = lens[1] * sizeof(int); diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh new file mode 100644 index 0000000000..d2d885d420 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") + +#bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE +# LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM +# SCHED_HOST SCHED_PORT ROLE +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 + +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=$RANK_SIZE +export LOCAL_WORKER_NUM=$5 +export LOCAL_SERVER_NUM=$6 +export MS_SERVER_NUM=$7 +export MS_SCHED_HOST=$8 +export MS_SCHED_PORT=$9 +export MS_ROLE=${10} +echo "=====Role is $MS_ROLE======" + + +if [ "$MS_ROLE" == "MS_SCHED" ];then +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 & +done +fi + +if [ "$MS_ROLE" == "MS_PSERVER" ];then +for((i=0;i<$LOCAL_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 & +done +fi + +if [ "$MS_ROLE" == "MS_WORKER" ];then +for((i=0;i<$LOCAL_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 & +done +fi diff --git a/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh new file mode 100644 index 0000000000..e1afc1dc14 --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +self_path=$(dirname "${script_self}") +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +DEVICE_TARGET=$1 +export MS_WORKER_NUM=$2 +export MS_SERVER_NUM=$3 +export MS_SCHED_HOST=$4 +export MS_SCHED_PORT=$5 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +wait $! +exit $? diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py new file mode 100644 index 0000000000..c08b5b9936 --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py @@ -0,0 +1,106 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Adam +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore import Parameter + +parser = argparse.ArgumentParser(description="test_sparse_embedding") +parser.add_argument("--device_target", type=str, default="Ascend") +args, _ = parser.parse_known_args() +device_target = args.device_target +context.set_context( + mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True +) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10): + super(LeNet5, self).__init__() + self.cast = P.Cast() + self.flatten = nn.Flatten() + self.embedding_table = Parameter( + initializer("normal", (16, 4), mstype.float32), name="embedding_table" + ) + self.embedding = nn.EmbeddingLookup() + self.relu = nn.ReLU() + self.fc = fc_with_initialize(12, num_class) + + def construct(self, x): + x = self.cast(x, mstype.int32) + x = self.embedding(self.embedding_table, x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def do_sparse_embedding(ps=False): + epoch = 10 + net = LeNet5(10) + if ps: + net.embedding_table.set_param_ps() + + optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) + optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") + criterion = nn.SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction="mean" + ) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.random.randint(0, 15, (32, 3), np.int32)) + label = Tensor(np.random.randint(0, 9, (32), np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses) + return losses + + +envs = os.environ +if __name__ == "__main__": + np.random.seed(0) + ps_loss = do_sparse_embedding(True) + + if envs.get("MS_ROLE") == "MS_WORKER": + envs["MS_ROLE"] = "" + np.random.seed(0) + no_ps_loss = do_sparse_embedding() + envs["MS_ROLE"] = "MS_WORKER" + + assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6) diff --git a/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py new file mode 100644 index 0000000000..bc400c963c --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_cmp_sparse_embedding(): + return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081") + assert return_code == 0 diff --git a/tests/st/ps/full_ps/shell_run_test.sh b/tests/st/ps/full_ps/shell_run_test.sh new file mode 100644 index 0000000000..8222e76888 --- /dev/null +++ b/tests/st/ps/full_ps/shell_run_test.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +self_path=$(dirname "${script_self}") +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +DEVICE_TARGET=$1 +DATASET_PATH=$2 +export MS_WORKER_NUM=$3 +export MS_SERVER_NUM=$4 +export MS_SCHED_HOST=$5 +export MS_SCHED_PORT=$6 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +wait $! +exit $? diff --git a/tests/st/ps/full_ps/test_entry_full_ps_lenet.py b/tests/st/ps/full_ps/test_entry_full_ps_lenet.py new file mode 100644 index 0000000000..9d11a52bc6 --- /dev/null +++ b/tests/st/ps/full_ps/test_entry_full_ps_lenet.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_full_ps_ascend_lenet(): + return_code = os.system( + "bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082" + ) + assert return_code == 0 diff --git a/tests/st/ps/full_ps/test_run.py b/tests/st/ps/multi_worker_full_ps/entry.py similarity index 75% rename from tests/st/ps/full_ps/test_run.py rename to tests/st/ps/multi_worker_full_ps/entry.py index 9cf70102e5..e54623144d 100644 --- a/tests/st/ps/full_ps/test_run.py +++ b/tests/st/ps/multi_worker_full_ps/entry.py @@ -17,14 +17,16 @@ import os # @pytest.mark.level0 # @pytest.mark.platform_arm_ascend_training # @pytest.mark.platform_x86_ascend_training -# @pytest.mark.env_onecard -def test_full_ps_ascend_lenet(): - return_code = os.system("bash run_full_ps_lenet.sh Ascend 1 1 127.0.0.1 8088") +# @pytest.mark.env_single +def test_multi_worker_full_ps_ascend_lenet(): + return_code = os.system("bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088") assert return_code == 0 + # @pytest.mark.level0 -# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.platform_arm_ascend_training +# @pytest.mark.platform_x86_ascend_training # @pytest.mark.env_onecard -def test_full_ps_gpu_lenet(): - return_code = os.system("bash run_full_ps_lenet.sh GPU 1 1 127.0.0.1 8088") +def test_full_ps_ascend_lenet(): + return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088") assert return_code == 0 diff --git a/tests/st/ps/full_ps/run_full_ps_lenet.sh b/tests/st/ps/multi_worker_full_ps/shell_run_test.sh similarity index 84% rename from tests/st/ps/full_ps/run_full_ps_lenet.sh rename to tests/st/ps/multi_worker_full_ps/shell_run_test.sh index c7d82d7040..47cb5a4dcf 100644 --- a/tests/st/ps/full_ps/run_full_ps_lenet.sh +++ b/tests/st/ps/multi_worker_full_ps/shell_run_test.sh @@ -32,7 +32,7 @@ do cd ${execute_path}/sched_$i/ || exit export RANK_ID=$i export DEVICE_ID=$i - python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & done export MS_ROLE=MS_PSERVER @@ -43,7 +43,7 @@ do cd ${execute_path}/server_$i/ || exit export RANK_ID=$i export DEVICE_ID=$i - python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & done export MS_ROLE=MS_WORKER @@ -54,7 +54,7 @@ do cd ${execute_path}/worker_$i/ || exit export RANK_ID=$i export DEVICE_ID=$i - python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & done wait $! diff --git a/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py b/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py new file mode 100644 index 0000000000..c08f923e0d --- /dev/null +++ b/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell + +parser = argparse.ArgumentParser(description="test_ps_lenet") +parser.add_argument("--device_target", type=str, default="Ascend") +args, _ = parser.parse_known_args() +device_target = args.device_target +context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_init=weight, + has_bias=False, + pad_mode="valid", + ) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=3): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +if __name__ == "__main__": + epoch = 5 + np.random.seed(0) + network = LeNet5(10) + network.set_param_ps() + criterion = nn.SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction="mean" + ) + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + + net_with_criterion = WithLossCell(network, criterion) + train_network = TrainOneStepCell(net_with_criterion, net_opt) + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) + label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses)