Add _bind_device_context api

This commit is contained in:
caifubi 2023-02-13 20:34:06 +08:00
parent 447c73b416
commit 8d91000807
17 changed files with 87 additions and 17 deletions

View File

@ -554,4 +554,5 @@ PYBIND11_MODULE(_c_expression, m) {
"Check whether a mindir file can be loaded and up to date.");
#endif
(void)m.def("_ms_memory_recycle", &mindspore::pipeline::MemoryRecycle, "Recycle memory used by mindspore.");
(void)m.def("_bind_device_ctx", &mindspore::pipeline::BindDeviceCtx, "Bind device context to current thread");
}

View File

@ -1850,6 +1850,8 @@ void MemoryRecycle() {
FuncGraphLoopBreaker::Inst().BreakLoop();
}
void BindDeviceCtx() { device::DeviceContextManager::GetInstance().BindDeviceCtx(); }
void ClearResPart1() {
runtime::OpExecutor::GetInstance().WorkerJoin();
// When the python process exits, the kernels on the device may not have finished executing.

View File

@ -183,6 +183,7 @@ void FinalizeBackend();
void ME_EXPORT ClearResAtexit();
void CloseTsd(bool force = false);
void MemoryRecycle();
void BindDeviceCtx();
FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
const std::string &dec_mode, const py::object decrypt = py::none(),

View File

@ -164,7 +164,7 @@ void AscendDeviceAddress::BindDevice() const {
if (!device_name_.empty()) {
auto ascend_device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(ascend_device_context);
if (!ascend_device_context->device_res_manager_->BindDeviceToCurrentThread()) {
if (!ascend_device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
MS_LOG(WARNING) << "Bind device to current thread failed.";
}
} else {

View File

@ -65,7 +65,7 @@ void AscendDeviceResManager::Destroy() {
MS_LOG(INFO) << "Device resource manager Destroy success.";
}
bool AscendDeviceResManager::BindDeviceToCurrentThread() const {
bool AscendDeviceResManager::BindDeviceToCurrentThread(bool /* force_bind */) const {
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetContext();
return true;

View File

@ -42,7 +42,7 @@ class AscendDeviceResManager : public DeviceResManager {
void Destroy() override;
// set rt_context_ to this thread to control device
bool BindDeviceToCurrentThread() const override;
bool BindDeviceToCurrentThread(bool /* force_bind */) const override;
// Relevant function to allocate and free device memory of raw ptr.
void *AllocateMemory(size_t size) const override;

View File

@ -366,7 +366,7 @@ bool AscendKernelExecutor::LaunchKernel(const CNodePtr &kernel, const vector<Add
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
KernelType kernel_type = AnfAlgo::GetKernelType(kernel);
MS_EXCEPTION_IF_NULL(kernel);
(void)res_manager_->BindDeviceToCurrentThread();
(void)res_manager_->BindDeviceToCurrentThread(false);
std::vector<AddressPtr> real_inputs;
bool ret = GetKernelRealInputs(kernel, inputs, &real_inputs);

View File

@ -93,7 +93,7 @@ bool GPUDeviceAddress::SyncHostToDevice(size_t size, const void *host_ptr) const
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
auto gpu_device_context = dynamic_cast<GPUDeviceContext *>(device_context);
MS_EXCEPTION_IF_NULL(gpu_device_context);
if (!gpu_device_context->device_res_manager_->BindDeviceToCurrentThread()) {
if (!gpu_device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
}

View File

@ -90,7 +90,7 @@ static thread_local bool cur_thread_device_inited{false};
void GPUDeviceContext::Initialize() {
if (initialized_) {
if (!device_res_manager_->BindDeviceToCurrentThread()) {
if (!device_res_manager_->BindDeviceToCurrentThread(false)) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
@ -222,7 +222,7 @@ void GPUDeviceContext::Destroy() {
void *GPUDeviceResManager::AllocateMemory(size_t size) const {
MS_EXCEPTION_IF_NULL(mem_manager_);
if (!BindDeviceToCurrentThread()) {
if (!BindDeviceToCurrentThread(false)) {
return nullptr;
}
return mem_manager_->MallocMemFromMemPool(size, false);
@ -247,7 +247,7 @@ bool GPUDeviceResManager::AllocateMemory(DeviceAddress *const &address) const {
return false;
}
if (!BindDeviceToCurrentThread()) {
if (!BindDeviceToCurrentThread(false)) {
return false;
}
if (auto_mem_offload_ != nullptr) {
@ -264,7 +264,7 @@ bool GPUDeviceResManager::AllocateMemory(DeviceAddress *const &address) const {
}
std::vector<void *> GPUDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
if (!BindDeviceToCurrentThread()) {
if (!BindDeviceToCurrentThread(false)) {
std::vector<void *> ptr_list;
return ptr_list;
}
@ -704,7 +704,7 @@ bool GPUKernelExecutor::LaunchKernel(const CNodePtr &kernel, const std::vector<A
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
size_t stream_id) const {
MS_EXCEPTION_IF_NULL(kernel);
if (!res_manager_->BindDeviceToCurrentThread()) {
if (!res_manager_->BindDeviceToCurrentThread(false)) {
return false;
}
bool ret = true;
@ -856,8 +856,8 @@ bool GPUDeviceResManager::LoadCollectiveCommLib() {
#endif
}
bool GPUDeviceResManager::BindDeviceToCurrentThread() const {
if (cur_thread_device_inited) {
bool GPUDeviceResManager::BindDeviceToCurrentThread(bool force_bind) const {
if (cur_thread_device_inited && !force_bind) {
return true;
}

View File

@ -41,7 +41,7 @@ class GPUDeviceResManager : public DeviceResManager {
// Release device memory, stream, cudnn and cublas handle, etc.
void Destroy() override;
bool BindDeviceToCurrentThread() const override;
bool BindDeviceToCurrentThread(bool force_bind) const override;
std::shared_ptr<void> AllocateHostMemory(size_t size) const override;

View File

@ -60,7 +60,7 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
// Launch custom func
MS_EXCEPTION_IF_NULL(node);
auto custom_func = AnfUtils::GetCustomFunc(node);
if (!device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread()) {
if (!device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread(false)) {
std::string error_info = "BindDevice to current thread failed: " + node->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
}

View File

@ -232,7 +232,7 @@ void EmbeddingCachePrefetchActor::Run() {
// Bind device to current thread to gain device control privileges
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
if (!device_context_->device_res_manager_->BindDeviceToCurrentThread()) {
if (!device_context_->device_res_manager_->BindDeviceToCurrentThread(false)) {
MS_LOG(ERROR) << "Failed to bind device to current thread.";
running_ = false;
PsDataPrefetch::GetInstance().NotifyFinalize();

View File

@ -103,7 +103,9 @@ class BACKEND_EXPORT DeviceResManager {
virtual void Destroy() {}
// Bind device to current thread to gain device control privileges
virtual bool BindDeviceToCurrentThread() const { return true; }
// If force_bind is true, bind context to current thread every time;
// Otherwise, only bind context to current thread for the first time.
virtual bool BindDeviceToCurrentThread(bool force_bind) const { return true; }
// Relevant function to allocate and free device memory of raw ptr.
virtual void *AllocateMemory(size_t size) const = 0;

View File

@ -462,6 +462,14 @@ void DeviceContextManager::ClearDeviceContexts() {
device_contexts_.clear();
}
void DeviceContextManager::BindDeviceCtx() const {
for (auto &iter : device_contexts_) {
MS_EXCEPTION_IF_NULL(iter.second);
MS_EXCEPTION_IF_NULL(iter.second->device_res_manager_);
iter.second->device_res_manager_->BindDeviceToCurrentThread(true);
}
}
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key,
string jit_level /* ="" */) {
std::string device_context_key_str = device_context_key.ToString();

View File

@ -57,6 +57,7 @@ class BACKEND_EXPORT DeviceContextManager {
void WaitTaskFinishOnDevice() const;
void UnloadPlugin();
std::string GetErrorMsg() const;
void BindDeviceCtx() const;
private:
DeviceContextManager() = default;

View File

@ -38,7 +38,7 @@ from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, \
PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
_ms_memory_recycle
_ms_memory_recycle, _bind_device_ctx
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
_get_pipeline_stages, _is_in_auto_parallel_mode
@ -1566,6 +1566,10 @@ def _generate_pair_password(obf_password):
return obf_password, append_password
def _bind_device_context():
"""Bind device context to current thread"""
_bind_device_ctx()
_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PyNativeExecutor()

View File

@ -0,0 +1,51 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.api import _bind_device_context
class Add(nn.Cell):
def __init__(self):
super(Add, self).__init__()
self.add = P.Add()
def construct(self, x1, x2):
return self.add(x1, x2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bind_cuda_ctx_api():
"""
Feature: _bind_device_ctx api
Description: Test _bind_device_ctx api.
Expectation: The results are as expected
"""
x1 = Tensor(np.array([1]))
x2 = Tensor(np.array([2]))
_bind_device_context()
net = Add()
output = net(x1, x2)
assert output.asnumpy() == np.array([3])
_bind_device_context()
output = net(x1, output)
assert output.asnumpy() == np.array([4])