Add _bind_device_context api
commit 8d91000807 (parent 447c73b416), forked from mindspore-Ecosystem/mindspore

@@ -554,4 +554,5 @@ PYBIND11_MODULE(_c_expression, m) {
                          "Check whether a mindir file can be loaded and up to date.");
 #endif
   (void)m.def("_ms_memory_recycle", &mindspore::pipeline::MemoryRecycle, "Recycle memory used by mindspore.");
+  (void)m.def("_bind_device_ctx", &mindspore::pipeline::BindDeviceCtx, "Bind device context to current thread.");
 }

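Once _c_expression is rebuilt, the new entry point is reachable from Python; a quick smoke check (hypothetical interactive session, not part of this commit):

from mindspore._c_expression import _bind_device_ctx
_bind_device_ctx()  # force-binds every device context created so far to the calling thread
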
@@ -1850,6 +1850,8 @@ void MemoryRecycle() {
   FuncGraphLoopBreaker::Inst().BreakLoop();
 }
 
+void BindDeviceCtx() { device::DeviceContextManager::GetInstance().BindDeviceCtx(); }
+
 void ClearResPart1() {
   runtime::OpExecutor::GetInstance().WorkerJoin();
   // When the python process exits, the kernels on the device may not have finished executing.

@@ -183,6 +183,7 @@ void FinalizeBackend();
 void ME_EXPORT ClearResAtexit();
 void CloseTsd(bool force = false);
 void MemoryRecycle();
+void BindDeviceCtx();
 
 FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
                         const std::string &dec_mode, const py::object decrypt = py::none(),

@@ -164,7 +164,7 @@ void AscendDeviceAddress::BindDevice() const {
   if (!device_name_.empty()) {
     auto ascend_device_context = GetDeviceContext();
     MS_EXCEPTION_IF_NULL(ascend_device_context);
-    if (!ascend_device_context->device_res_manager_->BindDeviceToCurrentThread()) {
+    if (!ascend_device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
       MS_LOG(WARNING) << "Bind device to current thread failed.";
     }
   } else {

@@ -65,7 +65,7 @@ void AscendDeviceResManager::Destroy() {
   MS_LOG(INFO) << "Device resource manager Destroy success.";
 }
 
-bool AscendDeviceResManager::BindDeviceToCurrentThread() const {
+bool AscendDeviceResManager::BindDeviceToCurrentThread(bool /* force_bind */) const {
   MS_EXCEPTION_IF_NULL(runtime_instance_);
   runtime_instance_->SetContext();
   return true;

@@ -42,7 +42,7 @@ class AscendDeviceResManager : public DeviceResManager {
   void Destroy() override;
 
   // set rt_context_ to this thread to control device
-  bool BindDeviceToCurrentThread() const override;
+  bool BindDeviceToCurrentThread(bool /* force_bind */) const override;
 
   // Relevant function to allocate and free device memory of raw ptr.
   void *AllocateMemory(size_t size) const override;

@@ -366,7 +366,7 @@ bool AscendKernelExecutor::LaunchKernel(const CNodePtr &kernel, const vector<Add
   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
   KernelType kernel_type = AnfAlgo::GetKernelType(kernel);
   MS_EXCEPTION_IF_NULL(kernel);
-  (void)res_manager_->BindDeviceToCurrentThread();
+  (void)res_manager_->BindDeviceToCurrentThread(false);
 
   std::vector<AddressPtr> real_inputs;
   bool ret = GetKernelRealInputs(kernel, inputs, &real_inputs);

@@ -93,7 +93,7 @@ bool GPUDeviceAddress::SyncHostToDevice(size_t size, const void *host_ptr) const
       device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
     auto gpu_device_context = dynamic_cast<GPUDeviceContext *>(device_context);
     MS_EXCEPTION_IF_NULL(gpu_device_context);
-    if (!gpu_device_context->device_res_manager_->BindDeviceToCurrentThread()) {
+    if (!gpu_device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
      MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
    }
  }

@@ -90,7 +90,7 @@ static thread_local bool cur_thread_device_inited{false};
 
 void GPUDeviceContext::Initialize() {
   if (initialized_) {
-    if (!device_res_manager_->BindDeviceToCurrentThread()) {
+    if (!device_res_manager_->BindDeviceToCurrentThread(false)) {
       MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
     }
     GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();

@@ -222,7 +222,7 @@ void GPUDeviceContext::Destroy() {
 
 void *GPUDeviceResManager::AllocateMemory(size_t size) const {
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (!BindDeviceToCurrentThread()) {
+  if (!BindDeviceToCurrentThread(false)) {
     return nullptr;
   }
   return mem_manager_->MallocMemFromMemPool(size, false);

@@ -247,7 +247,7 @@ bool GPUDeviceResManager::AllocateMemory(DeviceAddress *const &address) const {
     return false;
   }
 
-  if (!BindDeviceToCurrentThread()) {
+  if (!BindDeviceToCurrentThread(false)) {
     return false;
   }
   if (auto_mem_offload_ != nullptr) {

@@ -264,7 +264,7 @@ bool GPUDeviceResManager::AllocateMemory(DeviceAddress *const &address) const {
 }
 
 std::vector<void *> GPUDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
-  if (!BindDeviceToCurrentThread()) {
+  if (!BindDeviceToCurrentThread(false)) {
     std::vector<void *> ptr_list;
     return ptr_list;
   }

@@ -704,7 +704,7 @@ bool GPUKernelExecutor::LaunchKernel(const CNodePtr &kernel, const std::vector<A
                                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
                                      size_t stream_id) const {
   MS_EXCEPTION_IF_NULL(kernel);
-  if (!res_manager_->BindDeviceToCurrentThread()) {
+  if (!res_manager_->BindDeviceToCurrentThread(false)) {
     return false;
   }
   bool ret = true;

@@ -856,8 +856,8 @@ bool GPUDeviceResManager::LoadCollectiveCommLib() {
 #endif
 }
 
-bool GPUDeviceResManager::BindDeviceToCurrentThread() const {
-  if (cur_thread_device_inited) {
+bool GPUDeviceResManager::BindDeviceToCurrentThread(bool force_bind) const {
+  if (cur_thread_device_inited && !force_bind) {
     return true;
   }
 
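The GPU path caches the bind in the thread_local flag cur_thread_device_inited, so the many force_bind=false call sites in this patch pay the bind cost only once per thread, while force_bind=true rebinds unconditionally. A rough, runnable Python model of that guard (illustrative only; the stub and function names are invented for the sketch, no real device calls):

import threading

_tls = threading.local()  # models the C++ thread_local cur_thread_device_inited

def _set_device_stub():
    """Stand-in for the real per-thread device call (a cudaSetDevice-style API)."""

def bind_device_to_current_thread(force_bind=False):
    # Mirrors: if (cur_thread_device_inited && !force_bind) { return true; }
    if getattr(_tls, "inited", False) and not force_bind:
        return True  # fast path taken by the force_bind=false call sites above
    _set_device_stub()
    _tls.inited = True
    return True
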
@@ -41,7 +41,7 @@ class GPUDeviceResManager : public DeviceResManager {
   // Release device memory, stream, cudnn and cublas handle, etc.
   void Destroy() override;
 
-  bool BindDeviceToCurrentThread() const override;
+  bool BindDeviceToCurrentThread(bool force_bind) const override;
 
   std::shared_ptr<void> AllocateHostMemory(size_t size) const override;
 
@@ -60,7 +60,7 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
     // Launch custom func
     MS_EXCEPTION_IF_NULL(node);
     auto custom_func = AnfUtils::GetCustomFunc(node);
-    if (!device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread()) {
+    if (!device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread(false)) {
       std::string error_info = "BindDevice to current thread failed: " + node->fullname_with_scope();
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
     }

@@ -232,7 +232,7 @@ void EmbeddingCachePrefetchActor::Run() {
   // Bind device to current thread to gain device control privileges
   MS_EXCEPTION_IF_NULL(device_context_);
   MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
-  if (!device_context_->device_res_manager_->BindDeviceToCurrentThread()) {
+  if (!device_context_->device_res_manager_->BindDeviceToCurrentThread(false)) {
     MS_LOG(ERROR) << "Failed to bind device to current thread.";
     running_ = false;
     PsDataPrefetch::GetInstance().NotifyFinalize();

@@ -103,7 +103,9 @@ class BACKEND_EXPORT DeviceResManager {
   virtual void Destroy() {}
 
   // Bind device to current thread to gain device control privileges
-  virtual bool BindDeviceToCurrentThread() const { return true; }
+  // If force_bind is true, bind context to current thread every time;
+  // otherwise, only bind context to current thread for the first time.
+  virtual bool BindDeviceToCurrentThread(bool force_bind) const { return true; }
 
   // Relevant function to allocate and free device memory of raw ptr.
   virtual void *AllocateMemory(size_t size) const = 0;

@@ -462,6 +462,14 @@ void DeviceContextManager::ClearDeviceContexts() {
   device_contexts_.clear();
 }
 
+void DeviceContextManager::BindDeviceCtx() const {
+  for (auto &iter : device_contexts_) {
+    MS_EXCEPTION_IF_NULL(iter.second);
+    MS_EXCEPTION_IF_NULL(iter.second->device_res_manager_);
+    (void)iter.second->device_res_manager_->BindDeviceToCurrentThread(true);
+  }
+}
+
 DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key,
                                                               string jit_level /* ="" */) {
   std::string device_context_key_str = device_context_key.ToString();

@@ -57,6 +57,7 @@ class BACKEND_EXPORT DeviceContextManager {
   void WaitTaskFinishOnDevice() const;
   void UnloadPlugin();
   std::string GetErrorMsg() const;
+  void BindDeviceCtx() const;
 
  private:
  DeviceContextManager() = default;

@@ -38,7 +38,7 @@ from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle
+    _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
     _get_pipeline_stages, _is_in_auto_parallel_mode

@@ -1566,6 +1566,10 @@ def _generate_pair_password(obf_password):
     return obf_password, append_password
 
 
+def _bind_device_context():
+    """Bind the device context to the current thread."""
+    _bind_device_ctx()
+
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
 
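The wrapper matters mainly when MindSpore work runs on a thread the framework did not create, since device contexts are bound per thread. A minimal usage sketch (assuming a constructed net and input x already exist; both names are placeholders):

import threading
from mindspore.common.api import _bind_device_context

def worker(net, x, results):
    _bind_device_context()  # force-bind the live device contexts to this worker thread
    results.append(net(x, x))

results = []
t = threading.Thread(target=worker, args=(net, x, results))  # net, x defined elsewhere
t.start()
t.join()
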
@@ -0,0 +1,51 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common.api import _bind_device_context
+
+
+class Add(nn.Cell):
+    def __init__(self):
+        super(Add, self).__init__()
+        self.add = P.Add()
+
+    def construct(self, x1, x2):
+        return self.add(x1, x2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_bind_cuda_ctx_api():
+    """
+    Feature: _bind_device_ctx api
+    Description: Test _bind_device_ctx api.
+    Expectation: The results are as expected.
+    """
+    x1 = Tensor(np.array([1]))
+    x2 = Tensor(np.array([2]))
+
+    _bind_device_context()
+    net = Add()
+    output = net(x1, x2)
+    assert output.asnumpy() == np.array([3])
+
+    _bind_device_context()
+    output = net(x1, output)
+    assert output.asnumpy() == np.array([4])