forked from mindspore-Ecosystem/mindspore
!1937 execute ops asynchronously in pynative mode
Merge pull request !1937 from chujinjin/add_async_ops_excute_for_pynative
This commit is contained in:
commit
18e2a0f12e
|
@ -92,10 +92,29 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
|
|||
return true;
|
||||
}
|
||||
|
||||
void AscendDeviceAddress::SyncStream() const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->device_id();
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto ret = runtime_instance->SyncStream();
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Sync stream error!";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type,
|
||||
void *host_ptr) const {
|
||||
MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
|
||||
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
SyncStream();
|
||||
}
|
||||
bool sync_ok = false;
|
||||
std::vector<size_t> host_shape;
|
||||
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
|
||||
|
|
|
@ -44,6 +44,7 @@ class AscendDeviceAddress : public DeviceAddress {
|
|||
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
|
||||
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
|
||||
const void *host_ptr) const;
|
||||
void SyncStream() const;
|
||||
};
|
||||
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
|
||||
} // namespace ascend
|
||||
|
|
|
@ -41,12 +41,12 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
bool RunTask(const session::KernelGraph *graph) override;
|
||||
bool LoadTask(const session::KernelGraph *graph) override;
|
||||
void ClearGraphRuntimeResource(uint32_t graph_id) override;
|
||||
bool SyncStream() override;
|
||||
|
||||
protected:
|
||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
TypeId type_id) override;
|
||||
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
|
||||
bool SyncStream() override;
|
||||
|
||||
private:
|
||||
bool InitDevice();
|
||||
|
|
|
@ -28,6 +28,13 @@ namespace device {
|
|||
namespace gpu {
|
||||
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
auto &stream = GPUDeviceManager::GetInstance().default_stream();
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "SyncStream failed";
|
||||
return ret;
|
||||
}
|
||||
if (size != size_) {
|
||||
MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
|
||||
return true;
|
||||
|
|
|
@ -680,10 +680,6 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
|
|||
MS_LOG(ERROR) << "LaunchKernelMod failed!";
|
||||
return false;
|
||||
}
|
||||
if (!SyncStream()) {
|
||||
MS_LOG(ERROR) << "SyncStream failed!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ class KernelRuntime {
|
|||
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
|
||||
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
|
||||
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
|
||||
virtual bool SyncStream() = 0;
|
||||
|
||||
#ifdef ENABLE_DUMP_E2E
|
||||
DumpConfPtr GetDumpConf();
|
||||
|
@ -68,7 +69,6 @@ class KernelRuntime {
|
|||
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
TypeId type_id) = 0;
|
||||
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
|
||||
virtual bool SyncStream() = 0;
|
||||
void AssignStaticMemory(session::KernelGraph *graph);
|
||||
void AssignDynamicMemory(session::KernelGraph *graph);
|
||||
void ReuseAssignDynamicMemory(session::KernelGraph *graph);
|
||||
|
|
Loading…
Reference in New Issue