support group conv2d in pynative

yuchaojie 2021-06-26 11:11:11 +08:00
parent 29e42efd98
commit bf13a8a01e
15 changed files with 47 additions and 35 deletions

View File

@@ -204,6 +204,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
}
} // namespace
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
@@ -344,6 +345,20 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
}
}
+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+// data layout optimization
+AscendDataLayout(kernel_graph);
+// mixed precision optimization
+AscendMixPrecision(kernel_graph);
+// other optimization
+auto optimizer = std::make_shared<GraphOptimizer>();
+auto other_pm = std::make_shared<PassManager>("other_pm");
+other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
+optimizer->AddPassManager(other_pm);
+(void)optimizer->Optimize(kernel_graph);
+kernel_graph->SetExecOrderByDefault();
+}
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
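Note (not part of the diff): the new RunOpAscendBackendOptimization appears to be the single-op (PyNative) counterpart of AscendBackendOptimization, bundling the existing data-layout and mixed-precision passes with the new SetFraczGroupAttr pass. A minimal sketch of how a single-op path could chain it after the existing IR fusion entry point; the wrapper name below is hypothetical, and the real call site is the AscendSession::RunOpHardwareOptimize change further down:

// Sketch only: RunSingleOpOptimization is a made-up wrapper, not part of this commit.
void RunSingleOpOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  opt::RunOpAscendBackendIRFusionOptimization(kernel_graph);  // existing single-op IR fusion step
  opt::RunOpAscendBackendOptimization(kernel_graph);          // new entry point added in this file
}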

View File

@@ -21,6 +21,7 @@ namespace mindspore {
namespace opt {
void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);

View File

@@ -614,12 +614,9 @@ void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_
void AscendSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { Execute(kernel_graph, true); }
void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
-MS_LOG(INFO) << "Start";
-// data layout optimization
-opt::AscendDataLayout(kernel_graph);
-// mixed precision optimization
-opt::AscendMixPrecision(kernel_graph);
-MS_LOG(INFO) << "Finish";
+MS_LOG(INFO) << "HardwareOptimize Start";
+opt::RunOpAscendBackendOptimization(kernel_graph);
+MS_LOG(INFO) << "HardwareOptimize Finish";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {

View File

@@ -39,8 +39,8 @@ class AscendDeviceAddress : public DeviceAddress {
explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
-const AnfNodePtr &node, size_t out_index)
-: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+const KernelWithIndex &node_index)
+: DeviceAddress(ptr, size, format, type_id, node_index) {}
~AscendDeviceAddress() override;
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;

View File

@@ -334,8 +334,8 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_
}
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+TypeId type_id, const KernelWithIndex &node_index) {
+return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
}
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {

View File

@@ -67,7 +67,7 @@ class AscendKernelRuntime : public KernelRuntime {
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-const AnfNodePtr &node, size_t out_index) override;
+const KernelWithIndex &node_index) override;
bool KernelMemNotReuse(const AnfNodePtr &node) override;
void KernelLaunchProfiling(const std::string &kernel_name) override;

View File

@@ -31,9 +31,8 @@ class CPUDeviceAddress : public DeviceAddress {
CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
-CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-size_t out_index)
-: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+: DeviceAddress(ptr, size, format, type_id, node_index) {}
~CPUDeviceAddress() override = default;

View File

@@ -176,8 +176,8 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
}
DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+TypeId type_id, const KernelWithIndex &node_index) {
+return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
}
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(

View File

@@ -54,7 +54,7 @@ class CPUKernelRuntime : public KernelRuntime {
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-const AnfNodePtr &node, size_t out_index) override;
+const KernelWithIndex &node_index) override;
private:
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,

View File

@@ -66,9 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync {
explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
-explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-size_t out_index)
-: ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {}
+explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
+const KernelWithIndex &node_index)
+: ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_(node_index) {}
virtual ~DeviceAddress() { ptr_ = nullptr; }
const void *GetPtr() const { return ptr_; }
size_t GetSize() const { return size_; }
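Note (not part of the diff): the KernelWithIndex parameter replaces the separate node pointer and output index. Judging from the removed node_index_({node, out_index}) initializer and the brace-initialized call sites further below, KernelWithIndex behaves like an (AnfNodePtr, size_t) pair, so a call site changes roughly as in this sketch; ptr, size, format, type_id, node and out_index are placeholder values:

// Sketch, assuming KernelWithIndex is constructible from a {node, index} pair.
// old: auto addr = std::make_shared<AscendDeviceAddress>(ptr, size, format, type_id, node, out_index);
// new:
auto addr = std::make_shared<AscendDeviceAddress>(ptr, size, format, type_id, KernelWithIndex{node, out_index});
// ...equivalent to passing {node, out_index} inline, as the KernelRuntime call sites below do.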

View File

@@ -34,9 +34,8 @@ class GPUDeviceAddress : public DeviceAddress {
GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
-GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-size_t out_index)
-: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+: DeviceAddress(ptr, size, format, type_id, node_index) {}
~GPUDeviceAddress() override;
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;

View File

@@ -212,8 +212,8 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
}
DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+TypeId type_id, const KernelWithIndex &node_index) {
+return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
}
bool GPUKernelRuntime::InitDevice() {

View File

@@ -59,7 +59,7 @@ class GPUKernelRuntime : public KernelRuntime {
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-const AnfNodePtr &node, size_t out_index) override;
+const KernelWithIndex &node_index) override;
bool SyncStream() override;
bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;

View File

@@ -18,6 +18,7 @@
#include <functional>
#include <utility>
#include <vector>
+#include <set>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
@@ -187,7 +188,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address =
-CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(mem_manager_);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
@@ -220,7 +221,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
}
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
-auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i);
+auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
MS_EXCEPTION_IF_NULL(device_address);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
@@ -358,14 +359,14 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
MS_EXCEPTION_IF_NULL(address.addr);
device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
-output_type_id, item, index);
+output_type_id, {item, index});
AnfAlgo::SetOutputAddr(device_address, index, item.get());
continue;
}
#endif
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
device_address =
-CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
<< " node:" << item->fullname_with_scope() << " index: " << index;
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
@@ -536,7 +537,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
for (size_t j = 0; j < align_size_list.size(); ++j) {
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
-auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j);
+auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
MS_EXCEPTION_IF_NULL(address);
if (output_ptr == nullptr) {
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
@@ -572,7 +573,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
}
std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
-auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index);
+auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, {anf_node, index});
AnfAlgo::SetOutputAddr(address, index, anf_node.get());
return address;
}
@@ -665,7 +666,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
}
std::string output_format = AnfAlgo::GetOutputFormat(node, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i);
+auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
MS_EXCEPTION_IF_NULL(device_address);
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
MS_EXCEPTION_IF_NULL(ptr);
@@ -706,7 +707,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
}
auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
DeviceAddressPtr address =
-CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx);
+CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
MS_EXCEPTION_IF_NULL(address);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, node_size)) {

View File

@@ -115,7 +115,7 @@ class KernelRuntime {
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0;
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0;
+TypeId type_id, const KernelWithIndex &node_index) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool KernelMemNotReuse(const AnfNodePtr &node);