forked from mindspore-Ecosystem/mindspore
!48283 add optimize graph for ut test kernel executor
Merge pull request !48283 from kisnwang/add-device-common
This commit is contained in:
commit
a9d68b1101
|
@ -17,9 +17,13 @@
|
|||
#ifndef TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H
|
||||
#define TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "runtime/graph_scheduler/control_node_parser.h"
|
||||
#include "backend/common/optimizer/graph_optimizer.h"
|
||||
#include "backend/common/pass/communication_op_fusion.h"
|
||||
#include "backend/graph_compiler/backend.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
|
@ -88,41 +92,28 @@ class TestKernelExecutor : public device::KernelExecutor {
|
|||
public:
|
||||
TestKernelExecutor() = default;
|
||||
~TestKernelExecutor() override = default;
|
||||
|
||||
virtual void OptimizeGraph(const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto &nodes = kernel_graph->execution_order();
|
||||
for (const auto node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
SetKernelInfo(node);
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->kernel_info() == nullptr) {
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
node->set_kernel_info(kernel_info);
|
||||
}
|
||||
|
||||
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
if (kernel_info->select_kernel_build_info() == nullptr) {
|
||||
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
|
||||
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
(void)inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
}
|
||||
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
(void)outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
}
|
||||
|
||||
builder->SetOriginDataFormat(kOpFormat_DEFAULT);
|
||||
builder->SetInputsFormat(inputs_format);
|
||||
builder->SetInputsDeviceType(inputs_type);
|
||||
builder->SetOutputsFormat(outputs_format);
|
||||
builder->SetOutputsDeviceType(outputs_type);
|
||||
kernel_info->set_select_kernel_build_info(builder->Build());
|
||||
}
|
||||
SetKernelInfo(node);
|
||||
|
||||
std::vector<size_t> input_size_list;
|
||||
std::vector<size_t> output_size_list;
|
||||
|
@ -152,6 +143,45 @@ class TestKernelExecutor : public device::KernelExecutor {
|
|||
AnfAlgo::SetWorkspaceAddr(std::make_shared<TestDeviceAddress>(nullptr, kDefaultWorkSpaceSize), 0, node.get());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void SetKernelInfo(const CNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->kernel_info() == nullptr) {
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
node->set_kernel_info(kernel_info);
|
||||
}
|
||||
|
||||
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
if (kernel_info->select_kernel_build_info() != nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
(void)inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
}
|
||||
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
(void)outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
}
|
||||
|
||||
builder->SetOriginDataFormat(kOpFormat_DEFAULT);
|
||||
builder->SetInputsFormat(inputs_format);
|
||||
builder->SetInputsDeviceType(inputs_type);
|
||||
builder->SetOutputsFormat(outputs_format);
|
||||
builder->SetOutputsDeviceType(outputs_type);
|
||||
kernel_info->set_select_kernel_build_info(builder->Build());
|
||||
}
|
||||
};
|
||||
|
||||
class TestDeviceContext : public device::DeviceInterface<TestKernelExecutor, TestDeviceResManager> {
|
||||
|
|
Loading…
Reference in New Issue