diff --git a/build.sh b/build.sh
index 4427be77a4f..ab528257f74 100755
--- a/build.sh
+++ b/build.sh
@@ -391,6 +391,7 @@ checkopts()
         ENABLE_D="on"
         ENABLE_ACL="on"
         ENABLE_CPU="on"
+        ENABLE_MPI="on"
     else
         echo "Invalid value ${DEVICE_VERSION} for option -V"
         usage
diff --git a/cmake/package.cmake b/cmake/package.cmake
index 2e4dd74e6ca..9213f0d5427 100644
--- a/cmake/package.cmake
+++ b/cmake/package.cmake
@@ -163,6 +163,13 @@ if(ENABLE_MPI)
             COMPONENT mindspore
         )
     endif()
+    if(ENABLE_D)
+        install(
+            TARGETS _ascend_mpi
+            DESTINATION ${INSTALL_BASE_DIR}
+            COMPONENT mindspore
+        )
+    endif()
 endif()
 
 if(ENABLE_GPU)
@@ -180,6 +187,16 @@ if(ENABLE_GPU)
     )
 endif()
 
+if(ENABLE_D)
+    if(ENABLE_MPI)
+        install(
+            TARGETS ascend_collective
+            DESTINATION ${INSTALL_LIB_DIR}
+            COMPONENT mindspore
+        )
+    endif()
+endif()
+
 if(ENABLE_CPU AND NOT WIN32)
     install(
         TARGETS ps_cache
diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 33ebdc3887f..6419ea5a9f9 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -420,4 +420,10 @@ if(MODE_ASCEND_ALL)
     target_link_libraries(_c_expression PRIVATE ${adump_server})
 endif()
 
+if(ENABLE_D)
+    if(ENABLE_MPI)
+        set_target_properties(_ascend_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
+    endif()
+endif()
+
 add_subdirectory(cxx_api)
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 4fc01a93177..93a871326cc 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -72,6 +72,7 @@
 #include "transform/graph_ir/df_graph_manager.h"
 #include "transform/graph_ir/op_adapter_map.h"
 #include "runtime/device/ascend/profiling/profiling_manager.h"
+#include "runtime/device/ascend/distribute/ascend_collective.h"
 #endif
 #ifdef ENABLE_DUMP_IR
 #include "debug/rdr/running_data_recorder.h"
@@ -91,6 +92,7 @@ using mindspore::abstract::AbstractTuplePtr;
 
 #ifdef ENABLE_D
 using mindspore::device::ascend::ProfilingManager;
+using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
 #endif
 
 const char IR_TYPE_ANF[] = "anf_ir";
@@ -1216,6 +1218,23 @@ void InitHccl() {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+#if ENABLE_D
+  bool task_sink = true;
+  auto single_op = std::getenv(kAttrGraphOpRun);
+  if (single_op && std::string(single_op) == "1") {
+    task_sink = false;
+  }
+  auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
+  if (!task_sink && mode == kGraphMode) {
+    MS_LOG(INFO) << "mpi collective init.";
+    if (!HcclCollectiveGroup::instance().InitCollective()) {
+      MS_LOG(EXCEPTION) << "HcclCollectiveGroup init failed.";
+    }
+    device_id = IntToUint(HcclCollectiveGroup::instance().GetDeviceId());
+    ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
+    ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
+  }
+#endif
   std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
   if (ms_context->backend_policy() == "ms" &&
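In graph mode, `InitHccl()` now keys off the `GRAPH_OP_RUN` environment variable (`kAttrGraphOpRun`): when it is exported as `1`, initialization goes through the MPI-backed `HcclCollectiveGroup` and task sink is switched off. A minimal launch sketch of how a user would exercise this path (the script name and the 8-rank `mpirun` invocation are illustrative assumptions, mirroring the st tests added below):

```python
# run_nontask_sink.py -- illustrative only; names and rank count are assumptions.
# Launch with: mpirun --allow-run-as-root -n 8 python run_nontask_sink.py
import os

# Both variables must be exported before init() so that CreateBackend() and
# InitHccl() observe them when the backend is created.
os.environ['GRAPH_OP_RUN'] = '1'            # disable task sink -> MPI collective path
os.environ['HCCL_WHITELIST_DISABLE'] = '1'  # skip the HCCL whitelist check

from mindspore import context
from mindspore.communication.management import init, get_rank

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # selects the "hccl_mpi" backend because GRAPH_OP_RUN=1 in graph mode
print("global rank:", get_rank())
```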
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index b47a8ed77c8..014982d4d23 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -268,6 +268,10 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
   (void)ResetDevice(device_id);
   (void)ProfilingManager::GetInstance().StopProfiling();
   current_graph_ = nullptr;
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
+      !context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
+    HcclCollectiveGroup::instance().FinalizeCollective();
+  }
   MS_LOG(INFO) << "Ascend finalize end";
 }
 
diff --git a/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.cc b/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.cc
index b9719aab1e6..92257727fd5 100644
--- a/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.cc
@@ -16,6 +16,9 @@
 
 #include "runtime/device/ascend/distribute/mpi_pycc.h"
 #include <pybind11/operators.h>
+#include <string>
+#include <vector>
+
 namespace mindspore {
 namespace device {
 namespace ascend {
@@ -28,12 +31,16 @@ MpiPycc &MpiPycc::instance() {
 int MpiPycc::GetDeviceID() { return GetDeviceId(); }
 int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); }
 int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); }
+void MpiPycc::CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks) {
+  CreateCommForGroup(group, ranks);
+}
 
 // cppcheck-suppress syntaxError
 PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
   mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
   mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
   mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
+  mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group");
 }
 }  // namespace collective
 }  // namespace ascend
diff --git a/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.h b/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.h
index 5282a3bbc23..e63e075d19c 100644
--- a/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.h
+++ b/mindspore/ccsrc/runtime/device/ascend/distribute/mpi_pycc.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
 
 #include <string>
+#include <vector>
 #include "runtime/device/ascend/distribute/collective_group_wrapper.h"
 
 namespace mindspore {
@@ -32,6 +33,7 @@ class MpiPycc {
   static int GetDeviceID();
   static int GetRankId(const std::string &group);
   static int GetRankSize(const std::string &group);
+  static void CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks);
 
  private:
   MpiPycc() = default;
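The `_ascend_mpi` pybind module now exposes group creation alongside the existing rank and device queries. A hedged sketch of the binding surface from the Python side (the group name and rank list are illustrative; the module only exists in builds with `ENABLE_D` and `ENABLE_MPI`, inside an `mpirun`-launched process whose collective has already been initialized):

```python
# Illustrative only: requires an Ascend build with ENABLE_MPI=on and a prior
# collective init (e.g. via mindspore.communication.management.init()).
import mindspore._ascend_mpi as mpi

mpi.create_group("group_0_3", [0, 1, 2, 3])  # wraps MpiPycc::CreateGroup
print(mpi.get_rank_id("group_0_3"))          # wraps MpiPycc::GetRankId
print(mpi.get_rank_size("group_0_3"))        # wraps MpiPycc::GetRankSize
print(mpi.get_device_id())                   # wraps MpiPycc::GetDeviceID
```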
diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc
index b8db0812bed..fff36463122 100644
--- a/mindspore/ccsrc/utils/comm_manager.cc
+++ b/mindspore/ccsrc/utils/comm_manager.cc
@@ -22,6 +22,9 @@
 
 #ifndef NO_DLIB
 #include "runtime/hccl_adapter/hccl_adapter.h"
+#include "hccl/hcom.h"
+#include "runtime/device/ascend/distribute/ascend_collective.h"
+using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
 #endif
 
 #if defined(ENABLE_GPU)
@@ -69,9 +72,17 @@ bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
   auto rank_size = rank_id_list.size();
-  HCCL_RUN_CHECK(string("create communicate group"), group,
-                 hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
-                                                                  vector<unsigned int>(rank_id_list).data()));
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
+  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
+  if (!is_task_sink && mode == kGraphMode) {
+    HcclCollectiveGroup::instance().CreateCommGroup(group, rank_id_list);
+  } else {
+    HCCL_RUN_CHECK(string("create communicate group"), group,
+                   hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
+                                                                    vector<unsigned int>(rank_id_list).data()));
+  }
   return true;
 }
 
@@ -80,7 +91,11 @@ bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
   if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
-    HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
+    if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
+      *rank_id = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankId(group));
+    } else {
+      HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
+    }
   } else {
     HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(rank_id));
   }
@@ -92,7 +107,12 @@ bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
   if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
-    HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
+    if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
+      *rank_size = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankSize(group));
+    } else {
+      HCCL_RUN_CHECK(string("get rank size"), group,
+                     hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
+    }
   } else {
     HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(rank_size));
   }
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 1168c52a4de..a71320da66e 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -463,6 +463,7 @@ constexpr auto kAttrMultiCallEnd = "multicall_end";
 constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
 constexpr auto kAttrHiddenSize = "hidden_size";
 constexpr auto kAttrInputSize = "input_size";
+constexpr auto kAttrGraphOpRun = "GRAPH_OP_RUN";
 
 // primal attr key name
 constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc
index 404b87d6e45..3f3f670859e 100644
--- a/mindspore/ccsrc/vm/transform.cc
+++ b/mindspore/ccsrc/vm/transform.cc
@@ -580,6 +580,11 @@ BackendPtr CreateBackend() {
     if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
       backend->set_is_multi_graph_sink(false);
       context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
+    } else {
+      auto single_op = std::getenv(kAttrGraphOpRun);
+      if (single_op && std::string(single_op) == "1") {
+        context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
+      }
     }
   }
   return backend;
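`CreateBackend()` consults `GRAPH_OP_RUN` only outside PyNative mode, and `CommManager` then takes the MPI collective path whenever task sink is off in graph mode, falling back to the HCCL adapter otherwise. A condensed Python restatement of that decision logic (the function and variable names here are mine, purely for illustration):

```python
import os

def use_mpi_collective(mode: str) -> bool:
    """Illustrative mirror of the C++ checks in CreateBackend()/CommManager.

    The MPI-backed HcclCollectiveGroup path is taken only in graph mode with
    task sink disabled via GRAPH_OP_RUN=1; PyNative mode and task-sink graph
    mode keep using the HCCL adapter.
    """
    task_sink = not (mode == "graph" and os.getenv("GRAPH_OP_RUN") == "1")
    return mode == "graph" and not task_sink

assert use_mpi_collective("pynative") is False  # env var ignored in PyNative
os.environ["GRAPH_OP_RUN"] = "1"
assert use_mpi_collective("graph") is True      # graph mode + GRAPH_OP_RUN=1
```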
diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py
index e5c381be210..41aef4d9c04 100644
--- a/mindspore/communication/_comm_helper.py
+++ b/mindspore/communication/_comm_helper.py
@@ -19,6 +19,7 @@ from ._hccl_management import load_lib as hccl_load_lib
 
 _HCCL_AVAILABLE = False
 _NCCL_AVAILABLE = False
+_MPI_AVAILABLE = False
 try:
     import mindspore._ms_mpi as mpi
     _NCCL_AVAILABLE = True
@@ -34,6 +35,11 @@ except RuntimeError:
 
 if _HCCL_AVAILABLE:
     from . import _hccl_management as hccl
+    try:
+        import mindspore._ascend_mpi as mpi
+        _MPI_AVAILABLE = True
+    except ImportError:
+        _MPI_AVAILABLE = False
 else:
     try:
         import hccl_test.manage.api as hccl
@@ -68,6 +74,7 @@ class Backend:
     UNDEFINED = "undefined"
     HCCL = "hccl"
     NCCL = "nccl"
+    HCCL_MPI = "hccl_mpi"
 
     def __new__(cls, name):
         """Create instance object of Backend."""
@@ -105,6 +112,15 @@ def is_hccl_available():
     """
     return _HCCL_AVAILABLE
 
+def is_mpi_available():
+    """
+    Check whether the HCCL & MPI API is available.
+
+    Returns:
+        Boolean. Return whether the HCCL & MPI API is available or not.
+    """
+    return _MPI_AVAILABLE
+
 
 def is_nccl_available():
     """
@@ -145,11 +161,13 @@ def check_parameter_available(func):
         backend = kargs.get("backend")
         if backend is Backend.HCCL and not is_hccl_available():
             raise RuntimeError("Distributed Communication doesn't have HCCL built in")
+        if backend is Backend.HCCL_MPI and not is_mpi_available():
+            raise RuntimeError("Distributed Communication doesn't have MPI built in")
         if backend is Backend.NCCL and not is_nccl_available():
             raise RuntimeError("Distributed Communication doesn't have NCCL built in")
 
         if group is None:
-            if backend is Backend.HCCL:
+            if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
                 group = HCCL_WORLD_COMM_GROUP
             elif backend is Backend.NCCL:
                 group = NCCL_WORLD_COMM_GROUP
@@ -176,7 +194,9 @@ def _get_rank_helper(group, backend):
         if _is_role_pserver() or _is_role_sched():
             rank_id = 0
             return rank_id
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        rank_id = mpi.get_rank_id(group)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             rank_id = hccl.get_rank_id()
         else:
@@ -204,7 +224,9 @@ def _get_local_rank_helper(group, backend):
         Integer. The local rank id of the calling process.
     """
     rank_id = None
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        rank_id = mpi.get_rank_id(group)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             rank_id = hccl.get_local_rank_id()
         else:
@@ -235,7 +257,9 @@ def _get_size_helper(group, backend):
         if _is_role_pserver() or _is_role_sched():
             size = 1
             return size
-    if backend == Backend.HCCL:
+    if backend == Backend.HCCL_MPI:
+        size = mpi.get_rank_size(group)
+    elif backend == Backend.HCCL:
         if group == HCCL_WORLD_COMM_GROUP:
             size = hccl.get_rank_size()
         else:
@@ -360,6 +384,8 @@ def _create_group_helper(group, rank_ids, backend):
         if len(rank_ids) - len(list(set(rank_ids))) > 0:
             raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
         hccl.create_group(group, rank_size, rank_ids)
+    elif backend == Backend.HCCL_MPI:
+        mpi.create_group(group, rank_ids)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support create_group now.")
    else:
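With the new `Backend.HCCL_MPI` value, the private helpers dispatch to the `_ascend_mpi` bindings before falling back to the `hccl` shared-library wrappers. For illustration only (these are internal helpers; user code should go through `mindspore.communication.management`, and an initialized MPI world is assumed):

```python
# Hedged sketch of the internal dispatch, not a supported public API.
from mindspore.communication._comm_helper import (Backend, HCCL_WORLD_COMM_GROUP,
                                                  _get_rank_helper, _get_size_helper)

backend = Backend("hccl_mpi")  # normalized to Backend.HCCL_MPI
rank = _get_rank_helper(HCCL_WORLD_COMM_GROUP, backend)  # -> mpi.get_rank_id(group)
size = _get_size_helper(HCCL_WORLD_COMM_GROUP, backend)  # -> mpi.get_rank_size(group)
print(rank, size)
```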
+ """ + import os + task_sink = os.getenv("GRAPH_OP_RUN") + if task_sink: + try: + if int(task_sink) == 1: + return False + except ValueError: + return True + return True + def _check_parallel_envs(): """ @@ -86,7 +102,13 @@ def init(backend_name=None): """ if _is_role_pserver() or _is_role_sched(): return + task_sink = _check_task_sink_envs() device_target = context.get_context("device_target") + mode = context.get_context("mode") + mpi_init = False + if not task_sink and mode == context.GRAPH_MODE: + mpi_init = True + if backend_name is None: if device_target == "Ascend": backend_name = "hccl" @@ -101,9 +123,12 @@ def init(backend_name=None): if backend_name == "hccl": if device_target != "Ascend": raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target)) - _check_parallel_envs() + if not mpi_init: + _check_parallel_envs() + GlobalComm.BACKEND = Backend("hccl") + else: + GlobalComm.BACKEND = Backend("hccl_mpi") init_hccl() - GlobalComm.BACKEND = Backend("hccl") GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.INITED = True elif backend_name == "nccl": diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 4730432508c..9fb7397002f 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -43,6 +43,23 @@ def _get_pipeline_stages(): return auto_parallel_context().get_pipeline_stages() +def _check_task_sink_envs(): + """ + Check whether task_sink environment variables have been exported or not. + + return True if task_sink environment variables have been exported, False otherwise. + """ + import os + task_sink = os.getenv("SINGLE_OP_MODE") + if task_sink: + try: + if int(task_sink) == 1: + return False + except ValueError: + return True + return True + + def _check_full_batch(): """ full_batch could only be used under semi_auto_parallel or auto_parallel, check it. diff --git a/mindspore/train/model.py b/mindspore/train/model.py index d9e1640ff7e..a0dc92fde17 100644 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -26,7 +26,8 @@ from .._checkparam import check_input_data, check_output_data, Validator from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ - _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check + _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \ + _check_task_sink_envs from ..parallel._ps_context import _is_role_pserver, _is_role_sched from ..nn.metrics import Loss from .. import nn @@ -417,6 +418,14 @@ class Model: sink_size (int): Control the amount of data in each sink. Default: -1. """ epoch = Validator.check_positive_int(epoch) + if context.get_context("device_target") == "Ascend" and \ + context.get_context("mode") == context.GRAPH_MODE and not \ + _check_task_sink_envs() and \ + dataset_sink_mode: + dataset_sink_mode = False + logger.warning("The Ascend cannot support dataset sink when performed with nontask sink mode." + "So the training process will be performed with dataset not sink.") + if self._parameter_broadcast: self._train_network.set_broadcast_flag() @@ -830,6 +839,13 @@ class Model: dataset_sink_mode = False logger.warning("CPU cannot support dataset sink mode currently." 
"So the evaluating process will be performed with dataset non-sink mode.") + if context.get_context("device_target") == "Ascend" and \ + context.get_context("mode") == context.GRAPH_MODE and not \ + _check_task_sink_envs() and \ + dataset_sink_mode: + dataset_sink_mode = False + logger.warning("The Ascend cannot support dataset sink when performed with nontask sink mode." + "So the training process will be performed with dataset not sink.") with _CallbackManager(callbacks) as list_callback: if dataset_sink_mode: diff --git a/tests/st/nontask_sink/test_all.py b/tests/st/nontask_sink/test_all.py new file mode 100644 index 00000000000..64a622bced7 --- /dev/null +++ b/tests/st/nontask_sink/test_all.py @@ -0,0 +1,26 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest +from mindspore import context + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_single +def test_hccl_allreduce(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + return_code = os.system("mpirun --allow-run-as-root -n 8 pytest -s test_allreduce.py") + assert return_code == 0 diff --git a/tests/st/nontask_sink/test_allreduce.py b/tests/st/nontask_sink/test_allreduce.py new file mode 100644 index 00000000000..153ce394954 --- /dev/null +++ b/tests/st/nontask_sink/test_allreduce.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/st/nontask_sink/test_all.py b/tests/st/nontask_sink/test_all.py
new file mode 100644
index 00000000000..64a622bced7
--- /dev/null
+++ b/tests/st/nontask_sink/test_all.py
@@ -0,0 +1,26 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import pytest
+from mindspore import context
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_single
+def test_hccl_allreduce():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    return_code = os.system("mpirun --allow-run-as-root -n 8 pytest -s test_allreduce.py")
+    assert return_code == 0
diff --git a/tests/st/nontask_sink/test_allreduce.py b/tests/st/nontask_sink/test_allreduce.py
new file mode 100644
index 00000000000..153ce394954
--- /dev/null
+++ b/tests/st/nontask_sink/test_allreduce.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""test hccl allreduce with 8p"""
+
+import os
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.communication.management import init
+
+np.random.seed(1)
+os.environ['GRAPH_OP_RUN'] = str(1)
+os.environ['HCCL_WHITELIST_DISABLE'] = str(1)
+init()
+
+class AllReduceNet(nn.Cell):
+    def __init__(self):
+        super(AllReduceNet, self).__init__()
+        self.mul = P.Mul()
+        self.all_reduce = P.AllReduce()
+        self.add = P.Add()
+        self.y1 = Tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]).astype(np.float32))
+        self.y2 = Tensor(np.array([[-16, -16, -16, -16], [-16, -16, -16, -16], \
+                                   [-16, -16, -16, -16]]).astype(np.float32))
+
+    def construct(self, x):
+        x = self.mul(x, 2)
+        z = self.add(x, self.y1)
+        z = self.all_reduce(z)
+        out = self.add(z, self.y2)
+        out = self.all_reduce(out)
+        out = self.mul(out, 2)
+        return out
+
+def test_hccl_allreduce_8p():
+    net = AllReduceNet()
+    input_x = np.ones([3, 4]).astype(np.float32)
+    expect_output = [[256, 256, 256, 256], [256, 256, 256, 256], [256, 256, 256, 256]]
+    output = net(Tensor(input_x, mstype.float32))
+    assert np.allclose(output.asnumpy(), expect_output)
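The expected constant 256 in `test_hccl_allreduce_8p` follows from the two chained AllReduce sums over 8 identical ranks: since every rank feeds the same values, each AllReduce is simply a multiply-by-8. A scalar walk-through of one element:

```python
ranks = 8
x = 1.0 * 2        # input element after self.mul(x, 2)    -> 2
z = x + 2          # after self.add(x, self.y1)            -> 4
z = z * ranks      # first AllReduce (sum over 8 ranks)    -> 32
out = z + (-16)    # after self.add(z, self.y2)            -> 16
out = out * ranks  # second AllReduce                      -> 128
out = out * 2      # final self.mul(out, 2)                -> 256
assert out == 256
```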
diff --git a/tests/st/nontask_sink/test_lenet.py b/tests/st/nontask_sink/test_lenet.py
new file mode 100644
index 00000000000..9725916f548
--- /dev/null
+++ b/tests/st/nontask_sink/test_lenet.py
@@ -0,0 +1,179 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import time
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import context, Tensor, ParameterTuple
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.nn.optim import Momentum
+from mindspore.nn.wrap.cell_wrapper import WithLossCell
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+np.random.seed(1)
+grad_by_list = C.GradOperation(get_by_list=True)
+
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+class LeNet(nn.Cell):
+    """
+    Lenet network
+
+    Args:
+        num_class (int): Num classes. Default: 10.
+
+    Returns:
+        Tensor, output tensor
+
+    Examples:
+        >>> LeNet(num_class=10)
+    """
+
+    def __init__(self, num_class=10):
+        super(LeNet, self).__init__()
+        self.num_class = num_class
+        self.batch_size = 32
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (self.batch_size, -1))
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class CrossEntropyLoss(nn.Cell):
+    """
+    Define loss for network
+    """
+
+    def __init__(self):
+        super(CrossEntropyLoss, self).__init__()
+        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean()
+        self.one_hot = P.OneHot()
+        self.on_value = Tensor(1.0, mstype.float32)
+        self.off_value = Tensor(0.0, mstype.float32)
+        self.num = Tensor(32.0, mstype.float32)
+
+    def construct(self, logits, label):
+        label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
+        loss = self.cross_entropy(logits, label)[0]
+        loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num)
+        return loss
+
+
+class GradWrap(nn.Cell):
+    """
+    GradWrap definition
+    """
+
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
+
+    def construct(self, x, label):
+        weights = self.weights
+        return grad_by_list(self.network, weights)(x, label)
+
+
+def test_ascend_lenet():
+    epoch_size = 20
+    batch_size = 32
+    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
+    labels = Tensor(np.ones([batch_size]).astype(np.int32))
+
+    net = LeNet()
+    criterion = CrossEntropyLoss()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
+
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = GradWrap(net_with_criterion)
+    train_network.set_train()
+    total_time = 0
+
+    for epoch in range(0, epoch_size):
+        start_time = time.time()
+        fw_output = net(inputs)
+        loss_output = criterion(fw_output, labels)
+        grads = train_network(inputs, labels)
+        optimizer(grads)
+        end_time = time.time()
+        cost_time = end_time - start_time
+        total_time = total_time + cost_time
+
+        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
+    return loss_output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_lenet1():
+    os.environ['GRAPH_OP_RUN'] = str(1)
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    loss_output = test_ascend_lenet()
+    assert loss_output.asnumpy() < 0.004
+    assert loss_output.asnumpy() > 0.003
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_lenet2():
+    os.environ['GRAPH_OP_RUN'] = str(1)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    loss_output = test_ascend_lenet()
+    assert loss_output.asnumpy() < 0.004
+    assert loss_output.asnumpy() > 0.003
diff --git a/tests/st/pynative/data_parallel/test_pynative_hccl.py b/tests/st/pynative/data_parallel/test_pynative_hccl.py
index 12b935527cb..5b8f1241aef 100644
--- a/tests/st/pynative/data_parallel/test_pynative_hccl.py
+++ b/tests/st/pynative/data_parallel/test_pynative_hccl.py
@@ -86,3 +86,33 @@ def test_pynative_hccl_8p():
         os.system("rm -rf " + str(i))
 
     print("End training...")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_single
+def test_pynative_hccl_8pv2():
+    os.environ['GRAPH_OP_RUN'] = str(1)
+    device_num = 8
+    process = []
+    q = Queue()
+    for i in range(device_num):
+        device_id = i
+        process.append(Process(target=train_allreduce_8p, args=(q, device_id, device_num)))
+
+    for i in range(device_num):
+        process[i].start()
+
+    print("Waiting for all subprocesses done...")
+
+    for i in range(device_num):
+        process[i].join()
+
+    # check result
+    for i in range(device_num):
+        assert q.get()
+
+    for i in range(device_num):
+        os.system("rm -rf " + str(i))
+
+    print("End training...")
diff --git a/tests/ut/cpp/stub/hccl/collective_stub.cc b/tests/ut/cpp/stub/hccl/collective_stub.cc
index 06872fedd1b..1dad2d51b60 100644
--- a/tests/ut/cpp/stub/hccl/collective_stub.cc
+++ b/tests/ut/cpp/stub/hccl/collective_stub.cc
@@ -28,6 +28,7 @@ int HcclCollectiveGroup::GetRankSize(const std::string &) const { return 0; }
 int HcclCollectiveGroup::GetRankId(const std::string &) const { return 0; }
 int HcclCollectiveGroup::GetDeviceId() const { return 0; }
 void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector<unsigned int> &) { return; }
+void HcclCollectiveGroup::FinalizeCollective() { return; }
 }  // namespace collective
 }  // namespace ascend
 }  // namespace device
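End to end, the non-task-sink path composes with the existing public API: after `init()` under `GRAPH_OP_RUN=1`, `create_group` and the rank/size queries are served by `_ascend_mpi` instead of the `hccl` library. A final hedged sketch (the group name and rank list are illustrative; assumes an 8-rank `mpirun` launch on Ascend):

```python
import os
os.environ['GRAPH_OP_RUN'] = '1'
os.environ['HCCL_WHITELIST_DISABLE'] = '1'

from mindspore import context
from mindspore.communication.management import (init, create_group,
                                                get_group_size, get_rank)

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()                                   # GlobalComm.BACKEND == "hccl_mpi"
create_group("sub_group", [0, 1, 2, 3])  # routed to _ascend_mpi.create_group
print(get_rank(), get_group_size("sub_group"))
```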