forked from mindspore-Ecosystem/mindspore
!33524 LockRuntime to stream level
Merge pull request !33524 from TuDouNi/stream_lock
commit 42141e991b
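This change turns the single process-wide runtime lock into one lock per stream: LockRuntime() gains a stream parameter, and every call site passes the stream it is about to touch, so a blocking call such as rtStreamSynchronize on one stream no longer serializes launches and async copies issued on other streams. Below is a minimal standalone sketch of the locking scheme, assuming C++17 (guaranteed copy elision allows returning std::lock_guard by value) and with std::unordered_map standing in for mindspore::HashMap; the class name is illustrative, not MindSpore's actual code.

// Sketch only: per-stream locking, not the real KernelRuntime sources.
#include <mutex>
#include <unordered_map>

class RuntimeLockSketch {
 public:
  static std::lock_guard<std::mutex> LockRuntime(const void *stream) {
    static std::mutex map_mutex;  // guards only the lookup/insertion below
    static std::unordered_map<const void *, std::mutex> stream_mutexes;
    std::lock_guard<std::mutex> map_lock(map_mutex);
    // The per-stream mutex stays held by the caller until its guard leaves
    // scope; map_mutex is released as soon as this function returns.
    // unordered_map references are stable across rehash, so the returned
    // guard safely refers to the element in the map.
    return std::lock_guard<std::mutex>(stream_mutexes[stream]);
  }
};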
@@ -1114,6 +1114,8 @@ bool AscendKernelRuntime::SyncStream() {
   SetCurrentContext();
   session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
   for (auto &iter : stream_id_map_) {
+    // cppcheck-suppress unreadVariable
+    auto lock = device::KernelRuntime::LockRuntime(iter.second);
     if (rtStreamSynchronize(iter.second) != RT_ERROR_NONE) {  // o for switch stream
       MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
       return false;
@@ -1145,6 +1147,8 @@ bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size,
     MS_LOG(ERROR) << "rtMemcpyAsync size is 0, copy kind:" << kind;
     return false;
   }
+  // cppcheck-suppress unreadVariable
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   if (RT_ERROR_NONE != rtMemcpyAsync(dst, size, src, size, static_cast<rtMemcpyKind_t>(kind), stream_)) {
     MS_LOG(ERROR) << "Call runtime rtMemcpyAsync error.";
     return false;
@@ -161,7 +161,7 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
   MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_
                << ", args_size:" << args_.length();
   // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   rtArgsEx_t argsInfo = {};
   argsInfo.args = args_.data();
   argsInfo.argsSize = static_cast<uint32_t>(args_.length());
@@ -160,7 +160,7 @@ bool DynamicAicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, cons
     flag = RT_KERNEL_CUSTOM_AICPU;
   }
   // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   rtArgsEx_t argsInfo = {};
   argsInfo.args = args_.data();
   argsInfo.argsSize = static_cast<uint32_t>(args_.length());
@@ -200,7 +200,7 @@ void DynamicAicpuOpKernelMod::Wait() {
     return;
   }
   // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   auto ret = rtStreamSynchronize(stream_);
   if (ret != RT_ERROR_NONE) {
     MS_LOG(EXCEPTION) << "Call runtime rtStreamSynchronize failed. Op name: " << cnode->fullname_with_scope();
@@ -176,8 +176,6 @@ size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64

   auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
   auto ret = runtime_instance->SyncStream();
   if (!ret) {
     MS_LOG(EXCEPTION) << "Sync stream error!";
@@ -60,8 +60,6 @@ void TensorShapeKernelMod::Execute() {
   } else {
     auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
     MS_EXCEPTION_IF_NULL(runtime_instance);
-    // cppcheck-suppress unreadVariable
-    auto lock = device::KernelRuntime::LockRuntime();
     auto ret = runtime_instance->SyncStream();
     if (!ret) {
       MS_LOG(EXCEPTION) << "Sync stream error!";
@@ -150,8 +150,6 @@ void ReshapeKernelMod::Execute() {
      MS_LOG(EXCEPTION) << "Execute ReshapeKernel memcpy_s failed";
    }
  } else {
-    // cppcheck-suppress unreadVariable
-    auto lock = device::KernelRuntime::LockRuntime();
    if (!output_addr->AsyncDeviceToDevice(output_shapes, input_size_byte, address_x->type_id(), address_x->GetPtr(),
                                          address_x->format())) {
      MS_LOG(EXCEPTION) << "Host Reshape sync device to device failed.";
@@ -179,7 +177,7 @@ void ReshapeKernelMod::Execute(const std::vector<AddressPtr> &inputs, const std:

   size_t input_size_byte = LongToSize(GetArrProd(cnode)) * abstract::TypeIdSize(type_x);
   // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   auto status =
     rtMemcpyAsync(output_addr, outputs[0]->size, address_x, input_size_byte, RT_MEMCPY_DEVICE_TO_DEVICE, stream_);
   if (status != RT_ERROR_NONE) {
@@ -187,7 +187,7 @@ bool DynamicTbeKernelMod::CopyTilingToDevice(void *stream_ptr) {
     return true;
   }
   // cppcheck-suppress unreadVariable
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
   auto ret = aclrtMemcpyAsync(tiling_data_ptr_, op_para_size, tiling_data_.c_str(), tiling_data_.size(),
                               ACL_MEMCPY_HOST_TO_DEVICE, stream_ptr);
   if (ret != RT_ERROR_NONE) {
@@ -239,7 +239,7 @@ bool DynamicTbeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
     // Skip reduce if axis is a empty Tensor (shape = 0)
     MS_LOG(INFO) << "The node " << cnode->fullname_with_scope() << "Need Skip.";
     // cppcheck-suppress unreadVariable
-    auto lock = device::KernelRuntime::LockRuntime();
+    auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
     rtError_t status = aclrtMemcpyAsync(outputs[0]->addr, inputs[0]->size, inputs[0]->addr, inputs[0]->size,
                                         ACL_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
     if (status != RT_ERROR_NONE) {
@@ -277,7 +277,7 @@ bool DynamicTbeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
     const auto dev_func = std::to_string(tiling_key_);
     const auto kernel_info = node_info + "/" + std::to_string(tiling_key_);
     // cppcheck-suppress unreadVariable
-    auto lock = device::KernelRuntime::LockRuntime();
+    auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
     rtArgsEx_t args_info = {};
     args_info.args = runtimeargs.data();
     args_info.argsSize = args_size;
@@ -289,7 +289,7 @@ bool DynamicTbeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
     }
   } else {
     // cppcheck-suppress unreadVariable
-    auto lock = device::KernelRuntime::LockRuntime();
+    auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
     auto ret = rtKernelLaunch(func_stub_, block_dim_, runtimeargs.data(), args_size, l2ctrl, stream_ptr);
     if (ret != RT_ERROR_NONE) {
       MS_LOG(ERROR) << "Call runtime rtKernelLaunch error. Node info: " << node_info;
@@ -76,7 +76,7 @@ bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inpu
   rtL2Ctrl_t *l2ctrl = nullptr;
   const void *stubFunc = reinterpret_cast<void *>(func_stub);
   auto argsSize = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtimeargs.size());
-  auto lock = device::KernelRuntime::LockRuntime();
+  auto lock = device::KernelRuntime::LockRuntime(stream_);
   auto ret = rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_);
   if (ret != RT_ERROR_NONE) {
     MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
@@ -77,9 +77,11 @@ KernelRuntime::~KernelRuntime() {
   communication_stream_ = nullptr;
 }

-std::lock_guard<std::mutex> KernelRuntime::LockRuntime() {
-  static std::mutex mutex;
-  return std::lock_guard<std::mutex>(mutex);
+std::lock_guard<std::mutex> KernelRuntime::LockRuntime(const void *stream) {
+  static std::mutex mu_;
+  static mindspore::HashMap<const void *, std::mutex> mu_for_streams_;
+  std::lock_guard<std::mutex> lock(mu_);
+  return std::lock_guard<std::mutex>(mu_for_streams_[stream]);
 }

 bool KernelRuntime::Load(const session::KernelGraph &, bool) {
@@ -139,7 +139,7 @@ class KernelRuntime {
   void AssignDynamicMemory(const session::KernelGraph &graph);

   // lock runtime
-  static std::lock_guard<std::mutex> LockRuntime();
+  static std::lock_guard<std::mutex> LockRuntime(const void *stream);

  protected:
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
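For completeness, a hedged usage sketch built on the RuntimeLockSketch shown above, mirroring the call-site pattern in the hunks: take the guard for the stream you are about to use and let RAII release it when the runtime call returns. FakeRuntimeCall and the integer "stream handles" are placeholders, not Ascend runtime APIs; calls that target different streams may overlap, while calls on the same stream still serialize.

// Usage sketch; assumes RuntimeLockSketch from the earlier sketch is in scope.
#include <thread>

void FakeRuntimeCall(const void * /*stream*/) { /* stands in for rtKernelLaunch / rtMemcpyAsync */ }

void IssueOnStream(const void *stream) {
  // cppcheck-suppress unreadVariable
  auto lock = RuntimeLockSketch::LockRuntime(stream);  // per-stream guard, released at end of scope
  FakeRuntimeCall(stream);
}

int main() {
  int stream_a = 0, stream_b = 0;            // fake stream handles; their addresses act as map keys
  std::thread t1(IssueOnStream, &stream_a);  // different streams -> different mutexes, can run concurrently
  std::thread t2(IssueOnStream, &stream_b);
  t1.join();
  t2.join();
  return 0;
}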