forked from mindspore-Ecosystem/mindspore
!15563 [GraphKernel] Fix precision error when graph kernel is enabled
From: @zengzitao Reviewed-by: @limingqi107,@anyrenwei Signed-off-by: @anyrenwei
This commit is contained in:
commit
4189a0c06f
|
@ -484,6 +484,13 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs,
|
|||
if (node->isa<CNode>()) {
|
||||
auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize());
|
||||
AnfAlgo::SetOutputAddr(new_address, output_index, node.get());
|
||||
if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
auto runtime_instance =
|
||||
device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto gpu_runtime_instance = dynamic_cast<device::gpu::GPUKernelRuntime *>(runtime_instance);
|
||||
gpu_runtime_instance->SetAddrInvalid(address);
|
||||
}
|
||||
}
|
||||
|
||||
if (AnfAlgo::IsDynamicShape(node)) {
|
||||
|
|
|
@ -1220,12 +1220,21 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr
|
|||
addr_iter = iter;
|
||||
}
|
||||
|
||||
if (addr_iter->second[i] == nullptr) {
|
||||
auto &now_addr = addr_iter->second[i];
|
||||
if (now_addr == nullptr) {
|
||||
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
||||
addr_iter->second[i] = device_address;
|
||||
now_addr = device_address;
|
||||
addr_state_[now_addr] = true;
|
||||
} else {
|
||||
auto addr_state_iter = addr_state_.find(now_addr);
|
||||
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
|
||||
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
||||
now_addr = device_address;
|
||||
addr_state_[now_addr] = true;
|
||||
}
|
||||
}
|
||||
|
||||
return addr_iter->second[i];
|
||||
return now_addr;
|
||||
}
|
||||
|
||||
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
|
||||
|
@ -1244,12 +1253,21 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node,
|
|||
addr_iter = iter;
|
||||
}
|
||||
|
||||
if (addr_iter->second[i] == nullptr) {
|
||||
auto &now_addr = addr_iter->second[i];
|
||||
if (now_addr == nullptr) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||
addr_iter->second[i] = device_address;
|
||||
now_addr = device_address;
|
||||
addr_state_[now_addr] = true;
|
||||
} else {
|
||||
auto addr_state_iter = addr_state_.find(now_addr);
|
||||
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||
now_addr = device_address;
|
||||
addr_state_[now_addr] = true;
|
||||
}
|
||||
}
|
||||
|
||||
return addr_iter->second[i];
|
||||
return now_addr;
|
||||
}
|
||||
|
||||
session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) {
|
||||
|
|
|
@ -50,6 +50,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
// Reports the device address type handled by this runtime; always DeviceAddressType::kGPU.
// (The redundant ';' that followed the function body has been dropped: it was an empty
// declaration and triggers -Wextra-semi style warnings.)
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }
|
||||
// Accessor for the stream handle stored in stream_, returned as an opaque pointer.
void *compute_stream() const override {
  return stream_;
}
|
||||
// Accessor for the stream handle stored in communication_stream_, returned as an opaque pointer.
void *communication_stream() const override {
  return communication_stream_;
}
|
||||
// Records `addr` as invalid by storing `false` for it in addr_state_.
// If `addr` has no entry yet, operator[] creates one before the assignment.
// NOTE(review): addr_state_ is keyed by DeviceAddressPtr (presumably a shared_ptr); if so,
// every entry keeps its address object alive until it is erased from the map — confirm
// that entries are cleaned up somewhere.
void SetAddrInvalid(const DeviceAddressPtr &addr) {
  addr_state_[addr] = false;
}
|
||||
|
||||
protected:
|
||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
|
@ -121,6 +122,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
|
||||
bool enable_relation_cache_{false};
|
||||
|
||||
std::unordered_map<DeviceAddressPtr, bool> addr_state_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_;
|
||||
|
|
Loading…
Reference in New Issue