!30029 remove multi_stream for kernel by kernel

Merge pull request !30029 from baihuawei/seq_kernel_by_kernel
i-robot 2022-02-19 06:59:10 +00:00 committed by Gitee
commit 0acbf5d67f
2 changed files with 4 additions and 35 deletions
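
This change targets kernel-by-kernel (non-task-sink) execution: rather than assigning kernels to multiple streams and ordering them with pre/post event callbacks, every launch now goes to the single compute_stream_, so the event bookkeeping (GenKernelEvents, GetKernelEventFuncs) can be deleted outright. Below is a minimal sketch of the cross-stream ordering pattern that a single stream makes unnecessary; the stream handles and kernel launches are hypothetical, while the event calls are the real CANN runtime API.

#include "acl/acl.h"

// Sketch only: how multi-stream launch had to order a producer on stream_a
// before a consumer on stream_b. On a single compute stream, enqueue order
// alone guarantees this, so none of the event bookkeeping is needed.
void CrossStreamOrdering(aclrtStream stream_a, aclrtStream stream_b) {
  aclrtEvent event = nullptr;
  aclrtCreateEvent(&event);
  // ... enqueue producer kernel on stream_a ...
  aclrtRecordEvent(event, stream_a);      // the "pre/post event func" role
  aclrtStreamWaitEvent(stream_b, event);  // consumer blocks until it fires
  // ... enqueue consumer kernel on stream_b ...
  aclrtDestroyEvent(event);
}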


@@ -386,7 +386,6 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
     } else {
       PreprocessBeforeRunSingleOpGraph(graph);
       AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
-      GenKernelEvents(NOT_NULL(graph));
     }
   } catch (const std::exception &e) {
     ReportErrorMessage();
@@ -711,7 +710,7 @@ bool AscendDeviceContext::MemoryCopyAsync(const CNodePtr &node, const vector<Add
   }
   aclError status = aclrtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
-                                     ACL_MEMCPY_DEVICE_TO_DEVICE, GetKernelStream(node));
+                                     ACL_MEMCPY_DEVICE_TO_DEVICE, compute_stream_);
   if (status != ACL_ERROR_NONE) {
     MS_LOG(ERROR) << "MemCpyAsync op aclrtMemcpyAsync failed, ret:" << status;
     return false;
@@ -770,8 +769,6 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
   MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
   BindDeviceToCurrentThread();
-  auto event_funcs = runtime_instance_->GetKernelEventFuncs(kernel);
   std::vector<AddressPtr> real_inputs;
   bool ret = GetKernelRealInputs(kernel, inputs, &real_inputs);
   if (!ret) {
@@ -784,12 +781,6 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
   // start launch
   std::lock_guard<std::mutex> locker(launch_mutex_);
-  // launch pre events
-  MS_LOG(DEBUG) << "Launch pre-events for kernel " << kernel->fullname_with_scope();
-  for (auto &pre_event_func : event_funcs.first) {
-    pre_event_func();
-  }
   // launch atomic clean
   if (!LaunchAtomicClean(kernel, workspace, outputs)) {
     MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
@@ -812,7 +803,7 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
     dynamic_kernel->Execute();
     dynamic_kernel->PostExecute();
   } else {
-    ret = kernel_mod->Launch(real_inputs, workspace, outputs, GetKernelStream(kernel));
+    ret = kernel_mod->Launch(real_inputs, workspace, outputs, compute_stream_);
     if (!ret) {
       MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();
       return false;
@@ -820,12 +811,6 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
     }
   }
-  // launch post event
-  MS_LOG(DEBUG) << "Launch post-events for kernel " << kernel->fullname_with_scope();
-  for (auto &post_event_func : event_funcs.second) {
-    post_event_func();
-  }
   return PySyncRuning();
 }
@@ -872,7 +857,7 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
   // Launch Atomic Node
   auto kernel_mod = AnfAlgo::GetKernelMod(atomic_node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
-  return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(node));
+  return kernel_mod->Launch(atomic_inputs, {}, {}, compute_stream_);
 }

 void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
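
With per-node stream assignment gone, every GetKernelStream(node) lookup above collapses to the one compute stream, for memcpy, atomic clean, and kernel launches alike. A hedged sketch of the resulting launch shape, assuming only the CANN calls already visible in the diff (the free function and its parameters are illustrative, not the repository's):

#include "acl/acl.h"

// Copy then launch on one stream: the stream itself orders the two, so no
// event is recorded or waited on between them.
bool CopyThenLaunch(void *dst, size_t dst_size, const void *src, size_t src_size,
                    aclrtStream compute_stream) {
  aclError status = aclrtMemcpyAsync(dst, dst_size, src, src_size,
                                     ACL_MEMCPY_DEVICE_TO_DEVICE, compute_stream);
  if (status != ACL_ERROR_NONE) {
    return false;  // same error contract as MemoryCopyAsync above
  }
  // ... enqueue the kernel on compute_stream next; it runs after the copy ...
  return true;
}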


@@ -30,8 +30,7 @@ from .._checkparam import check_input_data, check_output_data, Validator
 from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
-    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
-    _check_task_sink_envs
+    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
 from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..nn.metrics import Loss
 from .. import nn
@@ -496,14 +495,6 @@ class Model:
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
         epoch = Validator.check_positive_int(epoch)
-        if context.get_context("device_target") == "Ascend" and \
-                context.get_context("mode") == context.GRAPH_MODE and not \
-                _check_task_sink_envs() and \
-                dataset_sink_mode:
-            dataset_sink_mode = False
-            logger.warning("The Ascend cannot support dataset sink when performed with nontask sink mode."
-                           "So the training process will be performed with dataset not sink.")
         if self._parameter_broadcast:
             self._train_network.set_broadcast_flag()
@@ -954,13 +945,6 @@ class Model:
             dataset_sink_mode = False
             logger.warning("CPU cannot support dataset sink mode currently."
                            "So the evaluating process will be performed with dataset non-sink mode.")
-        if context.get_context("device_target") == "Ascend" and \
-                context.get_context("mode") == context.GRAPH_MODE and not \
-                _check_task_sink_envs() and \
-                dataset_sink_mode:
-            dataset_sink_mode = False
-            logger.warning("The Ascend cannot support dataset sink when performed with nontask sink mode."
-                           "So the training process will be performed with dataset not sink.")
         with _CallbackManager(callbacks) as list_callback:
             if dataset_sink_mode:
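
On the Python side, the _check_task_sink_envs fallback that silently disabled dataset sink for Ascend graph mode is removed from both the train and eval paths, so the caller's dataset_sink_mode is now honored. A hypothetical end-to-end sketch of the caller's view (toy network and data, not from this PR):

import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Model, context

# Before this change, this configuration without the task-sink environment
# variables forced dataset_sink_mode back to False with a warning; after it,
# the requested value is passed through unchanged.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(4, 2)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_ds = ds.NumpySlicesDataset(
    {"data": np.random.randn(32, 4).astype(np.float32),
     "label": np.random.randn(32, 2).astype(np.float32)},
    shuffle=False).batch(8)

model = Model(net, loss_fn=nn.MSELoss(), optimizer=opt)
model.train(epoch=1, train_dataset=train_ds, dataset_sink_mode=True)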