forked from mindspore-Ecosystem/mindspore
!15563 [GraphKernel] Fix precision error when graph kernel is enabled
From: @zengzitao Reviewed-by: @limingqi107,@anyrenwei Signed-off-by: @anyrenwei
This commit is contained in:
commit
4189a0c06f
|
@ -484,6 +484,13 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs,
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize());
|
auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize());
|
||||||
AnfAlgo::SetOutputAddr(new_address, output_index, node.get());
|
AnfAlgo::SetOutputAddr(new_address, output_index, node.get());
|
||||||
|
if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
  // The node has just been given a new output address object, so tell the GPU runtime that the
  // previous `address` is no longer valid (SetAddrInvalid records it as `false` in the runtime's
  // address-state map). NOTE(review): presumably this makes the runtime's cached address lookups
  // re-fetch the node's current address -- confirm against GPUKernelRuntime.
  auto runtime_instance =
    device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  auto gpu_runtime_instance = dynamic_cast<device::gpu::GPUKernelRuntime *>(runtime_instance);
  // dynamic_cast yields nullptr when the runtime is not actually a GPUKernelRuntime; raise an
  // exception here instead of dereferencing a null pointer on the next line.
  MS_EXCEPTION_IF_NULL(gpu_runtime_instance);
  gpu_runtime_instance->SetAddrInvalid(address);
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::IsDynamicShape(node)) {
|
if (AnfAlgo::IsDynamicShape(node)) {
|
||||||
|
|
|
@ -1220,12 +1220,21 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr
|
||||||
addr_iter = iter;
|
addr_iter = iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (addr_iter->second[i] == nullptr) {
|
auto &now_addr = addr_iter->second[i];
|
||||||
|
if (now_addr == nullptr) {
|
||||||
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
||||||
addr_iter->second[i] = device_address;
|
now_addr = device_address;
|
||||||
|
addr_state_[now_addr] = true;
|
||||||
|
} else {
|
||||||
|
auto addr_state_iter = addr_state_.find(now_addr);
|
||||||
|
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
|
||||||
|
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
||||||
|
now_addr = device_address;
|
||||||
|
addr_state_[now_addr] = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return addr_iter->second[i];
|
return now_addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
|
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
|
||||||
|
@ -1244,12 +1253,21 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node,
|
||||||
addr_iter = iter;
|
addr_iter = iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (addr_iter->second[i] == nullptr) {
|
auto &now_addr = addr_iter->second[i];
|
||||||
|
if (now_addr == nullptr) {
|
||||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||||
addr_iter->second[i] = device_address;
|
now_addr = device_address;
|
||||||
|
addr_state_[now_addr] = true;
|
||||||
|
} else {
|
||||||
|
auto addr_state_iter = addr_state_.find(now_addr);
|
||||||
|
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
|
||||||
|
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||||
|
now_addr = device_address;
|
||||||
|
addr_state_[now_addr] = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return addr_iter->second[i];
|
return now_addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) {
|
session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) {
|
||||||
|
|
|
@ -50,6 +50,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
||||||
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; };
|
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; };
|
||||||
void *compute_stream() const override { return stream_; }
|
void *compute_stream() const override { return stream_; }
|
||||||
void *communication_stream() const override { return communication_stream_; }
|
void *communication_stream() const override { return communication_stream_; }
|
||||||
|
// Marks `addr` as invalid by storing `false` for it in addr_state_.
// An entry is created for `addr` if the map does not hold one yet.
void SetAddrInvalid(const DeviceAddressPtr &addr) {
  auto &valid_flag = addr_state_[addr];
  valid_flag = false;
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||||
|
@ -121,6 +122,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
||||||
|
|
||||||
bool enable_relation_cache_{false};
|
bool enable_relation_cache_{false};
|
||||||
|
|
||||||
|
std::unordered_map<DeviceAddressPtr, bool> addr_state_;
|
||||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_;
|
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_;
|
||||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_;
|
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_;
|
||||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_;
|
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_;
|
||||||
|
|
Loading…
Reference in New Issue