!18906 support group conv2d in pynative

Merge pull request !18906 from yuchaojie/op_select2
i-robot, 2021-06-26 08:10:57 +00:00 (committed by Gitee)
commit 80003511a7
15 changed files with 47 additions and 35 deletions
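
In short, this change does two things: (1) the PyNative run-op path gains a backend-optimization entry point, RunOpAscendBackendOptimization, which additionally runs a SetFraczGroupAttr pass that group Conv2D needs on Ascend; (2) the (const AnfNodePtr &node, size_t out_index) parameter pairs on the DeviceAddress constructors and the CreateDeviceAddress overloads are collapsed into a single KernelWithIndex argument. As a sketch only (the alias itself is not shown in this diff, but the replaced parameters and the old node_index_({node, out_index}) initializer imply it), KernelWithIndex is a (node, output index) pair:

// Sketch, not part of this commit: the pair type the refactor pivots on.
#include <memory>
#include <utility>

class AnfNode;                                // opaque IR node, as in the diff
using AnfNodePtr = std::shared_ptr<AnfNode>;  // shared ownership, matching the diff
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;  // node + output index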

@@ -204,6 +204,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
 }
 }  // namespace
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
@@ -344,6 +345,20 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   }
 }
+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  // data layout optimization
+  AscendDataLayout(kernel_graph);
+  // mixed precision optimization
+  AscendMixPrecision(kernel_graph);
+  // other optimization
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto other_pm = std::make_shared<PassManager>("other_pm");
+  other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
+  optimizer->AddPassManager(other_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+}
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
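
Note on the new function above: RunOpAscendBackendOptimization chains the existing data-layout and mixed-precision stages, then runs one extra pass manager whose only pass is SetFraczGroupAttr (presumably the pass that records the convolution group count on FRACTAL_Z-format weights; its body is not part of this diff), and finally restores the execution order. A self-contained mock of the optimizer/pass-manager idiom it uses, where every type is a simplified stand-in rather than the real MindSpore class:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct KernelGraph {};  // stand-in for session::KernelGraph

// A pass transforms the graph; this mock just reports that it ran.
class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
  virtual bool Run(const std::shared_ptr<KernelGraph> &graph) {
    std::cout << "running pass: " << name_ << "\n";
    return false;  // "graph changed" flag in the mock
  }

 private:
  std::string name_;
};

// Runs its passes in registration order.
class PassManager {
 public:
  explicit PassManager(std::string name) : name_(std::move(name)) {}
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run(const std::shared_ptr<KernelGraph> &graph) const {
    for (const auto &pass : passes_) {
      pass->Run(graph);
    }
  }

 private:
  std::string name_;
  std::vector<std::shared_ptr<Pass>> passes_;
};

// Runs its pass managers in registration order.
class GraphOptimizer {
 public:
  void AddPassManager(const std::shared_ptr<PassManager> &pm) { pass_managers_.push_back(pm); }
  void Optimize(const std::shared_ptr<KernelGraph> &graph) const {
    for (const auto &pm : pass_managers_) {
      pm->Run(graph);
    }
  }

 private:
  std::vector<std::shared_ptr<PassManager>> pass_managers_;
};

int main() {
  auto graph = std::make_shared<KernelGraph>();
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto other_pm = std::make_shared<PassManager>("other_pm");
  other_pm->AddPass(std::make_shared<Pass>("SetFraczGroupAttr"));
  optimizer->AddPassManager(other_pm);
  optimizer->Optimize(graph);  // prints: running pass: SetFraczGroupAttr
  return 0;
}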

@@ -21,6 +21,7 @@ namespace mindspore {
 namespace opt {
 void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);

@@ -614,12 +614,9 @@ void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
 void AscendSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { Execute(kernel_graph, true); }
 void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
-  MS_LOG(INFO) << "Start";
-  // data layout optimization
-  opt::AscendDataLayout(kernel_graph);
-  // mixed precision optimization
-  opt::AscendMixPrecision(kernel_graph);
-  MS_LOG(INFO) << "Finish";
+  MS_LOG(INFO) << "HardwareOptimize Start";
+  opt::RunOpAscendBackendOptimization(kernel_graph);
+  MS_LOG(INFO) << "HardwareOptimize Finish";
 }
 bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {

@@ -39,8 +39,8 @@ class AscendDeviceAddress : public DeviceAddress {
   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}
   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
-                               const AnfNodePtr &node, size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+                               const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}
   ~AscendDeviceAddress() override;
   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
   bool SyncHostToDevice(size_t size, const void *host_ptr) const override;

@@ -334,8 +334,8 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
 }
 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                          TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                          TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }
 bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {

@@ -67,7 +67,7 @@ class AscendKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;
   bool KernelMemNotReuse(const AnfNodePtr &node) override;
   void KernelLaunchProfiling(const std::string &kernel_name) override;

@@ -31,9 +31,8 @@ class CPUDeviceAddress : public DeviceAddress {
   CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}
-  CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                   size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+  CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}
   ~CPUDeviceAddress() override = default;

@@ -176,8 +176,8 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
 }
 DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                       TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }
 tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(

@@ -54,7 +54,7 @@ class CPUKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;
  private:
  tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,

@@ -66,9 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync {
   explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
   explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
-  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                         size_t out_index)
-      : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {}
+  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
+                         const KernelWithIndex &node_index)
+      : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_(node_index) {}
   virtual ~DeviceAddress() { ptr_ = nullptr; }
   const void *GetPtr() const { return ptr_; }
   size_t GetSize() const { return size_; }
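
Note on the constructor change above: taking const KernelWithIndex &node_index instead of separate node and index parameters is what lets the call sites later in this diff (kernel_runtime.cc) shrink to a braced {node, index} argument, since a braced initializer list converts to the pair parameter in place. A minimal, self-contained illustration (hypothetical types, not MindSpore code):

#include <cstdio>
#include <memory>
#include <utility>

struct Node { const char *name; };
using NodePtr = std::shared_ptr<Node>;
using KernelWithIndex = std::pair<NodePtr, size_t>;

// Same shape as the refactored overloads: one pair instead of (node, index).
void CreateDeviceAddress(const KernelWithIndex &node_index) {
  std::printf("node=%s output=%zu\n", node_index.first->name, node_index.second);
}

int main() {
  auto node = std::make_shared<Node>(Node{"Conv2D"});
  CreateDeviceAddress({node, 0});  // braced list constructs the KernelWithIndex in place
  return 0;
}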

@@ -34,9 +34,8 @@ class GPUDeviceAddress : public DeviceAddress {
   GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
   GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
       : DeviceAddress(ptr, size, format, type_id) {}
-  GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
-                   size_t out_index)
-      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
+  GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
+      : DeviceAddress(ptr, size, format, type_id, node_index) {}
   ~GPUDeviceAddress() override;
   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;

@@ -212,8 +212,8 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
 }
 DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
-  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
+                                                       TypeId type_id, const KernelWithIndex &node_index) {
+  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
 }
 bool GPUKernelRuntime::InitDevice() {

@@ -59,7 +59,7 @@ class GPUKernelRuntime : public KernelRuntime {
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
-                                       const AnfNodePtr &node, size_t out_index) override;
+                                       const KernelWithIndex &node_index) override;
   bool SyncStream() override;
   bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;

@@ -18,6 +18,7 @@
 #include <functional>
 #include <utility>
 #include <vector>
+#include <set>
 #include "backend/optimizer/common/helper.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/session/kernel_graph.h"
@@ -187,7 +188,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
     }
     auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
     auto device_address =
-        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
     MS_EXCEPTION_IF_NULL(device_address);
     MS_EXCEPTION_IF_NULL(mem_manager_);
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
@@ -220,7 +221,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
     }
     std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
-    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i);
+    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
     device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
     MS_EXCEPTION_IF_NULL(device_address);
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
@@ -358,14 +359,14 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
       const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
       MS_EXCEPTION_IF_NULL(address.addr);
       device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
-                                           output_type_id, item, index);
+                                           output_type_id, {item, index});
       AnfAlgo::SetOutputAddr(device_address, index, item.get());
       continue;
     }
 #endif
     auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
     device_address =
-        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
+        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
     MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                  << " node:" << item->fullname_with_scope() << " index: " << index;
     if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
@@ -536,7 +537,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
-    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j);
+    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
     MS_EXCEPTION_IF_NULL(address);
     if (output_ptr == nullptr) {
       output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
@@ -572,7 +573,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
   }
   std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
   auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
-  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, {anf_node, index});
   AnfAlgo::SetOutputAddr(address, index, anf_node.get());
   return address;
 }
@@ -665,7 +666,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
     }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i);
+    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
     MS_EXCEPTION_IF_NULL(device_address);
     uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
     MS_EXCEPTION_IF_NULL(ptr);
@@ -706,7 +707,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   }
   auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
   DeviceAddressPtr address =
-      CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx);
+      CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
   MS_EXCEPTION_IF_NULL(address);
   if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
       !mem_manager_->MallocMemFromMemPool(address, node_size)) {

@@ -115,7 +115,7 @@ class KernelRuntime {
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                TypeId type_id) = 0;
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                               TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0;
+                                               TypeId type_id, const KernelWithIndex &node_index) = 0;
   virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
   virtual bool KernelMemNotReuse(const AnfNodePtr &node);