!1937 execute ops asynchronously in pynative mode

Merge pull request !1937 from chujinjin/add_async_ops_excute_for_pynative
commit 18e2a0f12e
Author: mindspore-ci-bot
Date: 2020-06-10 18:47:56 +08:00
Committed by: Gitee
6 changed files with 29 additions and 6 deletions
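In PyNative mode this change stops blocking after every kernel launch (the SyncStream() call removed from KernelRuntime::LaunchKernel below) and instead synchronizes the device stream only when data is read back to the host (the SyncStream() call added to AscendDeviceAddress::SyncDeviceToHost and the equivalent sync in the GPU address). The sketch below is a minimal toy model of that pattern, not MindSpore code: ToyStream, Launch, and the future-chained execution are hypothetical stand-ins for the real device stream and kernel runtime.

// Hypothetical illustration only: ops are launched without blocking and run in
// launch order; the caller waits only when it needs the result on the host.
#include <functional>
#include <future>
#include <iostream>
#include <utility>
#include <vector>

class ToyStream {
 public:
  ToyStream() : tail_(std::async(std::launch::async, [] {})) {}

  // Returns immediately; op runs after all previously launched ops.
  // (Each Launch spawns a thread for simplicity; a real stream reuses one worker.)
  void Launch(std::function<void()> op) {
    tail_ = std::async(std::launch::async,
                       [prev = std::move(tail_), op = std::move(op)]() mutable {
                         prev.wait();  // preserve launch order
                         op();
                       });
  }

  // Block until every op launched so far has finished.
  void SyncStream() { tail_.wait(); }

 private:
  std::future<void> tail_;
};

int main() {
  std::vector<int> device_buf(4, 0);
  ToyStream stream;
  // "Execute" two ops asynchronously; Launch() returns right away.
  stream.Launch([&device_buf] { for (auto &v : device_buf) v += 1; });
  stream.Launch([&device_buf] { for (auto &v : device_buf) v *= 2; });
  // Synchronize only when the host needs the data, mirroring the SyncStream()
  // call this PR adds to SyncDeviceToHost for PyNative mode.
  stream.SyncStream();
  for (int v : device_buf) std::cout << v << ' ';  // prints: 2 2 2 2
  std::cout << std::endl;
  return 0;
}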

@@ -92,10 +92,29 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
  return true;
}

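// Block until every task already submitted to this device's stream has completed.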
void AscendDeviceAddress::SyncStream() const {
  MS_LOG(INFO) << "Start!";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->device_id();
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  auto ret = runtime_instance->SyncStream();
  if (!ret) {
    MS_LOG(EXCEPTION) << "Sync stream error!";
  }
  MS_LOG(INFO) << "Finish!";
}
bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type,
                                           void *host_ptr) const {
  MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
               << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
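  // Ops are launched asynchronously in PyNative mode, so wait for the device
  // stream to finish before copying the result back to the host.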
  if (ms_context->execution_mode() == kPynativeMode) {
    SyncStream();
  }
  bool sync_ok = false;
  std::vector<size_t> host_shape;
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);

@@ -44,6 +44,7 @@ class AscendDeviceAddress : public DeviceAddress {
  bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
  bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
                                        const void *host_ptr) const;
  void SyncStream() const;
};
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
} // namespace ascend

@@ -41,12 +41,12 @@ class AscendKernelRuntime : public KernelRuntime {
  bool RunTask(const session::KernelGraph *graph) override;
  bool LoadTask(const session::KernelGraph *graph) override;
  void ClearGraphRuntimeResource(uint32_t graph_id) override;
  bool SyncStream() override;
 protected:
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) override;
  bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
  bool SyncStream() override;
 private:
  bool InitDevice();

@@ -28,6 +28,13 @@ namespace device {
namespace gpu {
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
  MS_EXCEPTION_IF_NULL(host_ptr);
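  // Kernels run asynchronously on the default stream; make sure they have all
  // completed before validating sizes and copying data to the host.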
  auto &stream = GPUDeviceManager::GetInstance().default_stream();
  MS_EXCEPTION_IF_NULL(stream);
  auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
  if (!ret) {
    MS_LOG(ERROR) << "SyncStream failed";
    return ret;
  }
  if (size != size_) {
    MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
    return true;

@@ -680,10 +680,6 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  if (!SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed!";
    return false;
  }
  return true;
}

@@ -55,6 +55,7 @@ class KernelRuntime {
  virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
  virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
  virtual void ClearGraphRuntimeResource(uint32_t graph_id);
  virtual bool SyncStream() = 0;
#ifdef ENABLE_DUMP_E2E
  DumpConfPtr GetDumpConf();
@@ -68,7 +69,6 @@
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id) = 0;
  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
  virtual bool SyncStream() = 0;
  void AssignStaticMemory(session::KernelGraph *graph);
  void AssignDynamicMemory(session::KernelGraph *graph);
  void ReuseAssignDynamicMemory(session::KernelGraph *graph);