!48283 add optimize graph for ut test kernel executor

Merge pull request !48283 from kisnwang/add-device-common
This commit is contained in:
i-robot 2023-02-01 07:48:33 +00:00 committed by Gitee
commit a9d68b1101
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 62 additions and 32 deletions

View File

@ -17,9 +17,13 @@
#ifndef TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H
#define TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H
#include <memory>
#include "common/common_test.h"
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/control_node_parser.h"
#include "backend/common/optimizer/graph_optimizer.h"
#include "backend/common/pass/communication_op_fusion.h"
#include "backend/graph_compiler/backend.h"
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
@ -88,41 +92,28 @@ class TestKernelExecutor : public device::KernelExecutor {
public:
TestKernelExecutor() = default;
~TestKernelExecutor() override = default;
virtual void OptimizeGraph(const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto &nodes = kernel_graph->execution_order();
for (const auto node : nodes) {
MS_EXCEPTION_IF_NULL(node);
SetKernelInfo(node);
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
for (const auto node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info);
}
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (kernel_info->select_kernel_build_info() == nullptr) {
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
(void)inputs_format.emplace_back(kOpFormat_DEFAULT);
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
size_t output_num = AnfAlgo::GetOutputElementNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
(void)outputs_format.emplace_back(kOpFormat_DEFAULT);
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
}
builder->SetOriginDataFormat(kOpFormat_DEFAULT);
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type);
kernel_info->set_select_kernel_build_info(builder->Build());
}
SetKernelInfo(node);
std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list;
@ -152,6 +143,45 @@ class TestKernelExecutor : public device::KernelExecutor {
AnfAlgo::SetWorkspaceAddr(std::make_shared<TestDeviceAddress>(nullptr, kDefaultWorkSpaceSize), 0, node.get());
}
}
private:
void SetKernelInfo(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
if (node->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info);
}
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->select_kernel_build_info() != nullptr) {
return;
}
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
(void)inputs_format.emplace_back(kOpFormat_DEFAULT);
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
size_t output_num = AnfAlgo::GetOutputElementNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
(void)outputs_format.emplace_back(kOpFormat_DEFAULT);
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
}
builder->SetOriginDataFormat(kOpFormat_DEFAULT);
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type);
kernel_info->set_select_kernel_build_info(builder->Build());
}
};
class TestDeviceContext : public device::DeviceInterface<TestKernelExecutor, TestDeviceResManager> {