!15563 [GraphKernel] Fix precision error when graph kernel is enabled

From: @zengzitao
Reviewed-by: @limingqi107,@anyrenwei
Signed-off-by: @anyrenwei
This commit is contained in:
mindspore-ci-bot 2021-04-23 14:19:40 +08:00 committed by Gitee
commit 4189a0c06f
3 changed files with 33 additions and 6 deletions

View File

@ -484,6 +484,13 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs,
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize()); auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize());
AnfAlgo::SetOutputAddr(new_address, output_index, node.get()); AnfAlgo::SetOutputAddr(new_address, output_index, node.get());
if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
auto runtime_instance =
device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto gpu_runtime_instance = dynamic_cast<device::gpu::GPUKernelRuntime *>(runtime_instance);
gpu_runtime_instance->SetAddrInvalid(address);
}
} }
if (AnfAlgo::IsDynamicShape(node)) { if (AnfAlgo::IsDynamicShape(node)) {

View File

@ -1220,12 +1220,21 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr
addr_iter = iter; addr_iter = iter;
} }
if (addr_iter->second[i] == nullptr) { auto &now_addr = addr_iter->second[i];
if (now_addr == nullptr) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
addr_iter->second[i] = device_address; now_addr = device_address;
addr_state_[now_addr] = true;
} else {
auto addr_state_iter = addr_state_.find(now_addr);
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
now_addr = device_address;
addr_state_[now_addr] = true;
}
} }
return addr_iter->second[i]; return now_addr;
} }
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) { DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
@ -1244,12 +1253,21 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node,
addr_iter = iter; addr_iter = iter;
} }
if (addr_iter->second[i] == nullptr) { auto &now_addr = addr_iter->second[i];
if (now_addr == nullptr) {
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
addr_iter->second[i] = device_address; now_addr = device_address;
addr_state_[now_addr] = true;
} else {
auto addr_state_iter = addr_state_.find(now_addr);
if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) {
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
now_addr = device_address;
addr_state_[now_addr] = true;
}
} }
return addr_iter->second[i]; return now_addr;
} }
session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) { session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) {

View File

@ -50,6 +50,7 @@ class GPUKernelRuntime : public KernelRuntime {
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; };
void *compute_stream() const override { return stream_; } void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; } void *communication_stream() const override { return communication_stream_; }
// Marks `addr` as invalid by storing false for it in addr_state_; the entry is created if it is not present yet.
// The cached-address getters of this runtime re-fetch an address whose recorded state is false.
void SetAddrInvalid(const DeviceAddressPtr &addr) {
  auto &is_valid = addr_state_[addr];
  is_valid = false;
}
protected: protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@ -121,6 +122,7 @@ class GPUKernelRuntime : public KernelRuntime {
bool enable_relation_cache_{false}; bool enable_relation_cache_{false};
std::unordered_map<DeviceAddressPtr, bool> addr_state_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_; std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_; std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_; std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_;