Support group Conv2D in PyNative mode
parent 29e42efd98
commit bf13a8a01e
@@ -204,6 +204,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
 }
 } // namespace

 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
@@ -344,6 +345,20 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   }
 }

+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  // data layout optimization
+  AscendDataLayout(kernel_graph);
+  // mixed precision optimization
+  AscendMixPrecision(kernel_graph);
+  // other optimization
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto other_pm = std::make_shared<PassManager>("other_pm");
+  other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
+  optimizer->AddPassManager(other_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+}
+
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
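
Note on the new entry point above: RunOpAscendBackendOptimization chains the existing AscendDataLayout and AscendMixPrecision phases and then runs one extra pass group ("other_pm") containing SetFraczGroupAttr before re-deriving the execution order. Judging by the commit title, that extra pass is what carries the group attribute needed for group Conv2D weights in single-op (PyNative) graphs. The sketch below is not MindSpore code; it is a minimal, self-contained analogue of the GraphOptimizer / PassManager registration pattern, with hypothetical Graph, Pass, PassManager, GraphOptimizer, and SetGroupAttrPass stand-ins, showing how a pass registered on a pass manager gets applied when Optimize() runs.

// Minimal stand-alone analogue of the GraphOptimizer/PassManager pattern.
// All types here are illustrative stand-ins, not MindSpore's real classes.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph {
  int conv_groups = 2;  // pretend graph attribute relevant to group Conv2D
};

class Pass {
 public:
  virtual ~Pass() = default;
  virtual void Run(Graph *graph) = 0;
};

// Stand-in for a pass like SetFraczGroupAttr: tags weights with the group count.
class SetGroupAttrPass : public Pass {
 public:
  void Run(Graph *graph) override {
    std::cout << "tagging weights with groups=" << graph->conv_groups << "\n";
  }
};

class PassManager {
 public:
  explicit PassManager(std::string name) : name_(std::move(name)) {}
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run(Graph *graph) const {
    std::cout << "[" << name_ << "] running " << passes_.size() << " pass(es)\n";
    for (const auto &pass : passes_) {
      pass->Run(graph);
    }
  }

 private:
  std::string name_;
  std::vector<std::shared_ptr<Pass>> passes_;
};

class GraphOptimizer {
 public:
  void AddPassManager(const std::shared_ptr<PassManager> &pm) { pass_managers_.push_back(pm); }
  void Optimize(Graph *graph) const {
    for (const auto &pm : pass_managers_) {
      pm->Run(graph);
    }
  }

 private:
  std::vector<std::shared_ptr<PassManager>> pass_managers_;
};

int main() {
  Graph graph;
  // Mirrors the registration sequence in RunOpAscendBackendOptimization.
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto other_pm = std::make_shared<PassManager>("other_pm");
  other_pm->AddPass(std::make_shared<SetGroupAttrPass>());
  optimizer->AddPassManager(other_pm);
  optimizer->Optimize(&graph);
  return 0;
}

The real implementation differs in every detail (graphs are KernelGraph ANF graphs, passes report whether they changed anything, and so on); only the register-then-run structure is the point here.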
@@ -21,6 +21,7 @@ namespace mindspore {
 namespace opt {
 void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
@@ -614,12 +614,9 @@ void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_
 void AscendSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { Execute(kernel_graph, true); }

 void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
-  MS_LOG(INFO) << "Start";
-  // data layout optimization
-  opt::AscendDataLayout(kernel_graph);
-  // mixed precision optimization
-  opt::AscendMixPrecision(kernel_graph);
-  MS_LOG(INFO) << "Finish";
+  MS_LOG(INFO) << "HardwareOptimize Start";
+  opt::RunOpAscendBackendOptimization(kernel_graph);
+  MS_LOG(INFO) << "HardwareOptimize Finish";
 }

 bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
@@ -39,8 +39,8 @@ class AscendDeviceAddress : public DeviceAddress {
   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}
   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
-                               const AnfNodePtr &node, size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+                               const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}
   ~AscendDeviceAddress() override;
   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
   bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
@@ -334,8 +334,8 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
 }

 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                          TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                          TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }

 bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
@@ -67,7 +67,7 @@ class AscendKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;
   bool KernelMemNotReuse(const AnfNodePtr &node) override;

   void KernelLaunchProfiling(const std::string &kernel_name) override;
@@ -31,9 +31,8 @@ class CPUDeviceAddress : public DeviceAddress {
   CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}

-  CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                   size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+  CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}

   ~CPUDeviceAddress() override = default;

@@ -176,8 +176,8 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
 }

 DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                       TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }

 tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
@@ -54,7 +54,7 @@ class CPUKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;

  private:
   tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
@@ -66,9 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync {
   explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
   explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
-  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                         size_t out_index)
-      : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {}
+  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
+                         const KernelWithIndex &node_index)
+      : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_(node_index) {}
   virtual ~DeviceAddress() { ptr_ = nullptr; }
   const void *GetPtr() const { return ptr_; }
   size_t GetSize() const { return size_; }
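
Context for the constructor change above and the matching call-site edits in kernel_runtime.cc further down: the two parameters (node, out_index) collapse into one node_index argument. The removed initializer node_index_({node, out_index}) shows that node_index_ already stored the two values as a single node/output-index pair, so KernelWithIndex is evidently such a pair type (in MindSpore it is conventionally an alias along the lines of std::pair<AnfNodePtr, size_t>, though its definition is not part of this diff). The self-contained sketch below uses hypothetical Node and DeviceAddress stand-ins, not the real MindSpore classes, to show the pattern and why call sites can now pass a braced {node, index} where two arguments used to be.

// Illustrative stand-ins only; Node and DeviceAddress here are not MindSpore types.
#include <cstddef>
#include <iostream>
#include <memory>
#include <utility>

struct Node {
  const char *name = "";
};
using NodePtr = std::shared_ptr<Node>;

// Analogue of KernelWithIndex: which node produced the data, and which output slot.
using KernelWithIndex = std::pair<NodePtr, size_t>;

class DeviceAddress {
 public:
  // Old shape of the constructor (two trailing parameters):
  //   DeviceAddress(void *ptr, size_t size, const NodePtr &node, size_t out_index)
  //       : ptr_(ptr), size_(size), node_index_({node, out_index}) {}
  // New shape, as in this commit: the pair is passed in already formed.
  DeviceAddress(void *ptr, size_t size, const KernelWithIndex &node_index)
      : ptr_(ptr), size_(size), node_index_(node_index) {}

  void Describe() const {
    std::cout << (ptr_ ? "allocated" : "unallocated") << " address for "
              << node_index_.first->name << " output #" << node_index_.second
              << ", size " << size_ << " bytes\n";
  }

 private:
  void *ptr_;
  size_t size_;
  KernelWithIndex node_index_;
};

int main() {
  auto node = std::make_shared<Node>();
  node->name = "Conv2D";
  size_t index = 0;
  // The braced {node, index} builds the pair in place, which is why call sites
  // in kernel_runtime.cc change from (..., item, index) to (..., {item, index}).
  DeviceAddress addr(nullptr, 1024, {node, index});
  addr.Describe();
  return 0;
}

Packing the pair once also means every backend override of CreateDeviceAddress takes the same single argument, which is exactly what the Ascend, CPU, and GPU hunks in this commit do.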
@@ -34,9 +34,8 @@ class GPUDeviceAddress : public DeviceAddress {
   GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
   GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}
-  GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                   size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+  GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}
   ~GPUDeviceAddress() override;

   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
@@ -212,8 +212,8 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
 }

 DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                       TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }

 bool GPUKernelRuntime::InitDevice() {
@@ -59,7 +59,7 @@ class GPUKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;
   bool SyncStream() override;
   bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;

@@ -18,6 +18,7 @@
 #include <functional>
 #include <utility>
 #include <vector>
+#include <set>
 #include "backend/optimizer/common/helper.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/session/kernel_graph.h"
@@ -187,7 +188,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
   }
   auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
   auto device_address =
-    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
   MS_EXCEPTION_IF_NULL(device_address);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
@@ -220,7 +221,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
   }
   std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
   auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
-  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i);
+  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
   device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
   MS_EXCEPTION_IF_NULL(device_address);
   auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
@@ -358,14 +359,14 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
     const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
     MS_EXCEPTION_IF_NULL(address.addr);
     device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
-                                         output_type_id, item, index);
+                                         output_type_id, {item, index});
     AnfAlgo::SetOutputAddr(device_address, index, item.get());
     continue;
   }
 #endif
   auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
   device_address =
-    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
   MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                << " node:" << item->fullname_with_scope() << " index: " << index;
   if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
@@ -536,7 +537,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
-    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j);
+    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
     MS_EXCEPTION_IF_NULL(address);
     if (output_ptr == nullptr) {
       output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
@@ -572,7 +573,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
   }
   std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
   auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
-  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, {anf_node, index});
   AnfAlgo::SetOutputAddr(address, index, anf_node.get());
   return address;
 }
@@ -665,7 +666,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
     }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i);
+    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
     MS_EXCEPTION_IF_NULL(device_address);
     uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
     MS_EXCEPTION_IF_NULL(ptr);
@@ -706,7 +707,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   }
   auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
   DeviceAddressPtr address =
-    CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx);
+    CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
   MS_EXCEPTION_IF_NULL(address);
   if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
       !mem_manager_->MallocMemFromMemPool(address, node_size)) {
@@ -115,7 +115,7 @@ class KernelRuntime {
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                TypeId type_id) = 0;
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                               TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0;
+                                               TypeId type_id, const KernelWithIndex &node_index) = 0;
   virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
   virtual bool KernelMemNotReuse(const AnfNodePtr &node);
