!22173 ascend add nontask sink mode

Merge pull request !22173 from baihuawei/graph_mode_nonsink_part3-2
i-robot 2021-08-26 06:20:36 +00:00 committed by Gitee
commit 8e39dd4ec7
19 changed files with 469 additions and 12 deletions
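
This change lets Ascend graph mode run without task sink: exporting GRAPH_OP_RUN=1 makes InitHccl and the Python communication layer take an MPI-backed collective path (the new _ascend_mpi module), and Model falls back to non-sink dataset feeding. A minimal launch sketch pieced together from the new tests below; the script name run_net.py and the 8-device mpirun invocation are illustrative, not part of this diff:

# launched per the new test launcher: mpirun --allow-run-as-root -n 8 python run_net.py
import os
os.environ['GRAPH_OP_RUN'] = '1'            # request non-task-sink execution in graph mode
os.environ['HCCL_WHITELIST_DISABLE'] = '1'  # as set in the new allreduce test

from mindspore import context
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # with GRAPH_OP_RUN=1 in graph mode, init() selects the hccl_mpi backend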

View File

@@ -391,6 +391,7 @@ checkopts()
ENABLE_D="on"
ENABLE_ACL="on"
ENABLE_CPU="on"
ENABLE_MPI="on"
else
echo "Invalid value ${DEVICE_VERSION} for option -V"
usage

View File

@@ -163,6 +163,13 @@ if(ENABLE_MPI)
COMPONENT mindspore
)
endif()
if(ENABLE_D)
install(
TARGETS _ascend_mpi
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
endif()
endif()
if(ENABLE_GPU)
@@ -180,6 +187,16 @@ if(ENABLE_GPU)
)
endif()
if(ENABLE_D)
if(ENABLE_MPI)
install(
TARGETS ascend_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
endif()
if(ENABLE_CPU AND NOT WIN32)
install(
TARGETS ps_cache

View File

@@ -420,4 +420,10 @@ if(MODE_ASCEND_ALL)
target_link_libraries(_c_expression PRIVATE ${adump_server})
endif()
if(ENABLE_D)
if(ENABLE_MPI)
set_target_properties(_ascend_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
endif()
endif()
add_subdirectory(cxx_api)

View File

@@ -72,6 +72,7 @@
#include "transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/op_adapter_map.h"
#include "runtime/device/ascend/profiling/profiling_manager.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
@@ -91,6 +92,7 @@ using mindspore::abstract::AbstractTuplePtr;
#ifdef ENABLE_D
using mindspore::device::ascend::ProfilingManager;
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
#endif
const char IR_TYPE_ANF[] = "anf_ir";
@@ -1216,6 +1218,23 @@ void InitHccl() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
#if ENABLE_D
bool task_sink = true;
auto single_op = std::getenv(kAttrGraphOpRun);
if (single_op && std::string(single_op) == "1") {
task_sink = false;
}
auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!task_sink && mode == kGraphMode) {
MS_LOG(INFO) << "mpi collective init.";
if (!HcclCollectiveGroup::instance().InitCollective()) {
MS_LOG(EXCEPTION) << "HcclCollectiveGroup init failed.";
}
device_id = IntToUint(HcclCollectiveGroup::instance().GetDeviceId());
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
#endif
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
if (ms_context->backend_policy() == "ms" &&

View File

@@ -268,6 +268,10 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
(void)ResetDevice(device_id);
(void)ProfilingManager::GetInstance().StopProfiling();
current_graph_ = nullptr;
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
HcclCollectiveGroup::instance().FinalizeCollective();
}
MS_LOG(INFO) << "Ascend finalize end";
}

View File

@@ -16,6 +16,9 @@
#include "runtime/device/ascend/distribute/mpi_pycc.h"
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <vector>
namespace mindspore {
namespace device {
namespace ascend {
@@ -28,12 +31,16 @@ MpiPycc &MpiPycc::instance() {
int MpiPycc::GetDeviceID() { return GetDeviceId(); }
int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); }
int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); }
void MpiPycc::CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks) {
CreateCommForGroup(group, ranks);
}
// cppcheck-suppress syntaxError
PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group");
}
} // namespace collective
} // namespace ascend
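
The pybind11 module above is what the Python layer imports as mindspore._ascend_mpi. A short usage sketch, assuming the module was built with ENABLE_D and ENABLE_MPI and the process was started under mpirun; the group names and rank list are illustrative:

import mindspore._ascend_mpi as mpi

device_id = mpi.get_device_id()                      # MpiPycc::GetDeviceID
rank_id = mpi.get_rank_id("hccl_world_group")        # MpiPycc::GetRankId(group)
rank_size = mpi.get_rank_size("hccl_world_group")    # MpiPycc::GetRankSize(group)
mpi.create_group("group_0_3", [0, 1, 2, 3])          # MpiPycc::CreateGroup(group, ranks)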

View File

@@ -18,6 +18,7 @@
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
#include <string>
#include <vector>
#include "runtime/device/ascend/distribute/collective_group_wrapper.h"
namespace mindspore {
@@ -32,6 +33,7 @@ class MpiPycc {
static int GetDeviceID();
static int GetRankId(const std::string &group);
static int GetRankSize(const std::string &group);
static void CreateGroup(const std::string &group, const std::vector<unsigned int> &ranks);
private:
MpiPycc() = default;

View File

@@ -22,6 +22,9 @@
#ifndef NO_DLIB
#include "runtime/hccl_adapter/hccl_adapter.h"
#include "hccl/hcom.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
#endif
#if defined(ENABLE_GPU)
@@ -69,9 +72,17 @@ bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int
auto rank_size = rank_id_list.size();
HCCL_GROUP_CHECK_EMPTY(group);
HCCL_GROUP_CHECK_IS_WORLD(group);
HCCL_RUN_CHECK(string("create communicate group"), group,
hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
vector<unsigned int>(rank_id_list).data()));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!is_task_sink && mode == kGraphMode) {
HcclCollectiveGroup::instance().CreateCommGroup(group, rank_id_list);
} else {
HCCL_RUN_CHECK(string("create communicate group"), group,
hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
vector<unsigned int>(rank_id_list).data()));
}
return true;
}
@@ -80,7 +91,11 @@ bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
*rank_id = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankId(group));
} else {
HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
}
} else {
HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(rank_id));
}
@@ -92,7 +107,12 @@ bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) cons
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
*rank_size = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankSize(group));
} else {
HCCL_RUN_CHECK(string("get rank size"), group,
hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
}
} else {
HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(rank_size));
}

View File

@@ -463,6 +463,7 @@ constexpr auto kAttrMultiCallEnd = "multicall_end";
constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
constexpr auto kAttrHiddenSize = "hidden_size";
constexpr auto kAttrInputSize = "input_size";
constexpr auto kAttrGraphOpRun = "GRAPH_OP_RUN";
// primal attr key name
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";

View File

@@ -580,6 +580,11 @@ BackendPtr CreateBackend() {
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
backend->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
} else {
auto single_op = std::getenv(kAttrGraphOpRun);
if (single_op && std::string(single_op) == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
}
}
return backend;

View File

@@ -19,6 +19,7 @@ from ._hccl_management import load_lib as hccl_load_lib
_HCCL_AVAILABLE = False
_NCCL_AVAILABLE = False
_MPI_AVAILABLE = False
try:
import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True
@@ -34,6 +35,11 @@ except RuntimeError:
if _HCCL_AVAILABLE:
from . import _hccl_management as hccl
try:
import mindspore._ascend_mpi as mpi
_MPI_AVAILABLE = True
except ImportError:
_MPI_AVAILABLE = False
else:
try:
import hccl_test.manage.api as hccl
@@ -68,6 +74,7 @@ class Backend:
UNDEFINED = "undefined"
HCCL = "hccl"
NCCL = "nccl"
HCCL_MPI = "hccl_mpi"
def __new__(cls, name):
"""Create instance object of Backend."""
@@ -105,6 +112,15 @@ def is_hccl_available():
"""
return _HCCL_AVAILABLE
def is_mpi_available():
"""
Check whether the HCCL & MPI APIs are available.
Returns:
Boolean. Whether HCCL & MPI are available or not.
"""
return _MPI_AVAILABLE
def is_nccl_available():
"""
@@ -145,11 +161,13 @@ def check_parameter_available(func):
backend = kargs.get("backend")
if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available():
raise RuntimeError("Distributed Communication doesn't have MPI built in")
if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
if group is None:
if backend is Backend.HCCL:
if backend in (Backend.HCCL, Backend.HCCL_MPI):
group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP
@@ -176,7 +194,9 @@ def _get_rank_helper(group, backend):
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
rank_id = hccl.get_rank_id()
else:
@@ -204,7 +224,9 @@ def _get_local_rank_helper(group, backend):
Integer. The local rank id of the calling process.
"""
rank_id = None
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
rank_id = hccl.get_local_rank_id()
else:
@@ -235,7 +257,9 @@ def _get_size_helper(group, backend):
if _is_role_pserver() or _is_role_sched():
size = 1
return size
if backend == Backend.HCCL:
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_rank_size()
else:
@@ -360,6 +384,8 @@ def _create_group_helper(group, rank_ids, backend):
if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI:
mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support create_group now.")
else:
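
With the hccl_mpi backend selected, the helpers above route rank queries and group creation through mindspore._ascend_mpi instead of the HCCL management library. A sketch of the user-facing calls that reach these helpers, assuming init() has already run with GRAPH_OP_RUN=1 in an 8-device job; the group name and rank list are illustrative:

from mindspore.communication.management import get_rank, get_group_size, create_group

rank = get_rank()         # -> _get_rank_helper(..., Backend.HCCL_MPI) -> mpi.get_rank_id(group)
size = get_group_size()   # -> _get_size_helper(..., Backend.HCCL_MPI) -> mpi.get_rank_size(group)
create_group("group_even", [0, 2, 4, 6])  # -> _create_group_helper -> mpi.create_group(group, rank_ids)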

View File

@@ -36,6 +36,22 @@ def _get_group(group):
return GlobalComm.WORLD_COMM_GROUP
return group
def _check_task_sink_envs():
"""
Check whether the task sink environment variable has been exported or not.
Return False if GRAPH_OP_RUN is set to 1 (non-task-sink mode), True otherwise.
"""
import os
task_sink = os.getenv("GRAPH_OP_RUN")
if task_sink:
try:
if int(task_sink) == 1:
return False
except ValueError:
return True
return True
def _check_parallel_envs():
"""
@@ -86,7 +102,13 @@ def init(backend_name=None):
"""
if _is_role_pserver() or _is_role_sched():
return
task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target")
mode = context.get_context("mode")
mpi_init = False
if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True
if backend_name is None:
if device_target == "Ascend":
backend_name = "hccl"
@@ -101,9 +123,12 @@ if backend_name == "hccl":
if backend_name == "hccl":
if device_target != "Ascend":
raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
_check_parallel_envs()
if not mpi_init:
_check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl")
else:
GlobalComm.BACKEND = Backend("hccl_mpi")
init_hccl()
GlobalComm.BACKEND = Backend("hccl")
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
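
The new _check_task_sink_envs() drives the mpi_init decision in init(): it returns False only when the environment variable is set to 1, so an unset or non-"1" value keeps the default task-sink path. A small illustrative check of that behaviour (private helper, shown for clarity only):

import os
from mindspore.communication.management import _check_task_sink_envs

os.environ.pop("GRAPH_OP_RUN", None)
assert _check_task_sink_envs() is True    # unset -> task sink stays on

os.environ["GRAPH_OP_RUN"] = "1"
assert _check_task_sink_envs() is False   # "1" -> non-task-sink (mpi) path

os.environ["GRAPH_OP_RUN"] = "yes"
assert _check_task_sink_envs() is True    # non-numeric value -> ValueError branch -> True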

View File

@@ -43,6 +43,23 @@ def _get_pipeline_stages():
return auto_parallel_context().get_pipeline_stages()
def _check_task_sink_envs():
"""
Check whether the task sink environment variable has been exported or not.
Return False if it is set to 1 (non-task-sink mode), True otherwise.
"""
import os
task_sink = os.getenv("SINGLE_OP_MODE")
if task_sink:
try:
if int(task_sink) == 1:
return False
except ValueError:
return True
return True
def _check_full_batch():
"""
full_batch could only be used under semi_auto_parallel or auto_parallel, check it.

View File

@@ -26,7 +26,8 @@ from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
_check_task_sink_envs
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
@@ -417,6 +418,14 @@ class Model:
sink_size (int): Control the amount of data in each sink. Default: -1.
"""
epoch = Validator.check_positive_int(epoch)
if context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and not \
_check_task_sink_envs() and \
dataset_sink_mode:
dataset_sink_mode = False
logger.warning("Ascend does not support dataset sink mode when running in non-task-sink mode, "
"so the training process will be performed without dataset sink.")
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
@@ -830,6 +839,13 @@ dataset_sink_mode = False
dataset_sink_mode = False
logger.warning("CPU cannot support dataset sink mode currently."
"So the evaluating process will be performed with dataset non-sink mode.")
if context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and not \
_check_task_sink_envs() and \
dataset_sink_mode:
dataset_sink_mode = False
logger.warning("Ascend does not support dataset sink mode when running in non-task-sink mode, "
"so the evaluating process will be performed without dataset sink.")
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
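
The two guards added to Model.train and Model.eval mean that requesting dataset sink while GRAPH_OP_RUN=1 falls back to non-sink execution and logs a warning. A hedged end-to-end sketch on an Ascend device; the tiny Dense network, optimizer and generator dataset are placeholders, not part of this change:

import os
os.environ["GRAPH_OP_RUN"] = "1"          # non-task-sink mode

import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import context
from mindspore.train.model import Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(4, 2)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt)

def gen():
    for _ in range(8):
        yield (np.random.rand(4).astype(np.float32), np.array(1).astype(np.int32))

dataset = ds.GeneratorDataset(gen, ["data", "label"]).batch(2)

# dataset_sink_mode=True is overridden to False by the new check and a warning is logged
model.train(1, dataset, dataset_sink_mode=True)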

View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from mindspore import context
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_hccl_allreduce():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
return_code = os.system("mpirun --allow-run-as-root -n 8 pytest -s test_allreduce.py")
assert return_code == 0

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test hccl allreduce with 8p"""
import os
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.communication.management import init
np.random.seed(1)
os.environ['GRAPH_OP_RUN'] = str(1)
os.environ['HCCL_WHITELIST_DISABLE'] = str(1)
init()
class AllReduceNet(nn.Cell):
def __init__(self):
super(AllReduceNet, self).__init__()
self.mul = P.Mul()
self.all_reduce = P.AllReduce()
self.add = P.Add()
self.y1 = Tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]])).astype(np.float32)
self.y2 = Tensor(np.array([[-16, -16, -16, -16], [-16, -16, -16, -16], \
[-16, -16, -16, -16]])).astype(np.float32)
def construct(self, x):
x = self.mul(x, 2)
z = self.add(x, self.y1)
z = self.all_reduce(z)
out = self.add(z, self.y2)
out = self.all_reduce(out)
out = self.mul(out, 2)
return out
def test_hccl_allreduce_8p():
net = AllReduceNet()
input_x = np.ones([3, 4]).astype(np.float32)
expect_output = [[256, 256, 256, 256], [256, 256, 256, 256], [256, 256, 256, 256]]
output = net(Tensor(input_x, mstype.float32))
assert np.allclose(output.asnumpy(), expect_output)
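
The expected value of 256 follows from the two AllReduce sums across the 8 ranks started by the mpirun launcher in the previous file. A standalone check of the arithmetic (plain NumPy, no device required):

import numpy as np

ranks = 8
x = np.ones((3, 4), dtype=np.float32)
z = x * 2 + 2        # Mul then Add y1: every rank holds 4
z = z * ranks        # first AllReduce (sum over 8 identical ranks): 32
out = z - 16         # Add y2 (-16): 16
out = out * ranks    # second AllReduce: 128
out = out * 2        # final Mul: 256
assert np.allclose(out, 256)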

View File

@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import time
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.optim import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
np.random.seed(1)
grad_by_list = C.GradOperation(get_by_list=True)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
class LeNet(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes, Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet, self).__init__()
self.num_class = num_class
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class CrossEntropyLoss(nn.Cell):
"""
Define loss for network
"""
def __init__(self):
super(CrossEntropyLoss, self).__init__()
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.num = Tensor(32.0, mstype.float32)
def construct(self, logits, label):
label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
loss = self.cross_entropy(logits, label)[0]
loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num)
return loss
class GradWrap(nn.Cell):
"""
GradWrap definition
"""
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x, label):
weights = self.weights
return grad_by_list(self.network, weights)(x, label)
def test_ascend_lenet():
epoch_size = 20
batch_size = 32
inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
labels = Tensor(np.ones([batch_size]).astype(np.int32))
net = LeNet()
criterion = CrossEntropyLoss()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrap(net_with_criterion)
train_network.set_train()
total_time = 0
for epoch in range(0, epoch_size):
start_time = time.time()
fw_output = net(inputs)
loss_output = criterion(fw_output, labels)
grads = train_network(inputs, labels)
optimizer(grads)
end_time = time.time()
cost_time = end_time - start_time
total_time = total_time + cost_time
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
return loss_output
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet1():
os.environ['GRAPH_OP_RUN'] = str(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
loss_output = test_ascend_lenet()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet2():
os.environ['GRAPH_OP_RUN'] = str(1)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
loss_output = test_ascend_lenet()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003

View File

@@ -86,3 +86,33 @@ def test_pynative_hccl_8p():
os.system("rm -rf " + str(i))
print("End training...")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_pynative_hccl_8pv2():
os.environ['GRAPH_OP_RUN'] = str(1)
device_num = 8
process = []
q = Queue()
for i in range(device_num):
device_id = i
process.append(Process(target=train_allreduce_8p, args=(q, device_id, device_num)))
for i in range(device_num):
process[i].start()
print("Waiting for all subprocesses done...")
for i in range(device_num):
process[i].join()
# check result
for i in range(device_num):
assert q.get()
for i in range(device_num):
os.system("rm -rf " + str(i))
print("End training...")

View File

@@ -28,6 +28,7 @@ int HcclCollectiveGroup::GetRankSize(const std::string &) const { return 0; }
int HcclCollectiveGroup::GetRankId(const std::string &) const { return 0; }
int HcclCollectiveGroup::GetDeviceId() const { return 0; }
void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector<unsigned int> &) { return; }
void HcclCollectiveGroup::FinalizeCollective() { return; }
} // namespace collective
} // namespace ascend
} // namespace device