forked from mindspore-Ecosystem/mindspore
!4987 [bugfix]SyncDeviceToHost failed when device address size is zero
Merge pull request !4987 from zyli2020/bug_fix
This commit is contained in:
commit
33c7b49219
|
@ -65,6 +65,8 @@ void GPUSession::StartKernelRT() const {
|
|||
|
||||
void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
|
@ -73,9 +75,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
}
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -32,6 +32,10 @@ namespace device {
|
|||
namespace gpu {
|
||||
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
bool need_sync = (size != 0) && (size_ != 0);
|
||||
if (!need_sync) {
|
||||
return true;
|
||||
}
|
||||
auto &stream = GPUDeviceManager::GetInstance().default_stream();
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
|
||||
|
@ -48,6 +52,10 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, T
|
|||
|
||||
bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
bool need_sync = (size != 0) && (size_ != 0);
|
||||
if (!need_sync) {
|
||||
return true;
|
||||
}
|
||||
auto &stream = GPUDeviceManager::GetInstance().default_stream();
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
if (size != size_) {
|
||||
|
|
Loading…
Reference in New Issue