!48312 Add fused_tensor case to the MemUsageAnalyzer

Merge pull request !48312 from yangluhang/mem_analyze_fused_tensor
i-robot 2023-02-02 03:40:49 +00:00 committed by Gitee
commit 8514a58d6f
2 changed files with 60 additions and 0 deletions
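
The change adds an all_reduce_net helper, whose two AllReduce ops share fusion id 1, to the Python test nets, plus a C++ unit test that runs MemUsageAnalyzer on the resulting graph and checks the fused-tensor accounting: kernel and tensor info counts, least memory needed, the number of communication kernels, and the number of tensors carrying fused-tensor ids.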


@@ -25,3 +25,16 @@ def add_net(x1, x2, x3, x4, x5):
    sum4 = add(sum3, x5)
    ret = mul(sum4, sum1)
    return ret
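
# The two AllReduce ops below share fusion id 1, so the backend can fuse
# them into a single communication kernel; the new C++ test below checks
# that MemUsageAnalyzer reports exactly one comm kernel for this net.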
all_reduce = P.AllReduce().add_prim_attr("fusion", 1)
mul = P.Mul()
def all_reduce_net(x1, x2, x3):
    product = mul(x1, x2)
    sum1 = add(x2, x3)
    reduce1 = all_reduce(product)
    reduce2 = all_reduce(sum1)
    res = add(reduce1, reduce2)
    return res


@@ -59,4 +59,51 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
  ASSERT_EQ(100, analyzer->LeastMemNeeded());
}
/// Feature: MemUsageAnalyzer
/// Description: Test MemUsageAnalyzer interface with fused AllReduce nodes
/// Expectation: Pass all interface tests
TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer_fused_tensor) {
  auto net = get_py_func_("all_reduce_net");
  EXPECT_NE(net, nullptr);
  std::vector<int64_t> shp_x{2, 2, 2, 2};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
  auto func_graph = GetFuncGraph(net, args_spec_list);
  auto kernel_graph = Compile(func_graph);
  auto analyzer = std::make_shared<MemUsageAnalyzer>();
  analyzer->Analyze(kernel_graph);
  auto kernel_infos = analyzer->GetMemUsageKernelInfos();
  auto tensor_infos = analyzer->GetMemUsageTensorInfos();
  ASSERT_EQ(4, kernel_infos.size());
  ASSERT_EQ(14, tensor_infos.size());
  ASSERT_EQ(260, analyzer->LeastMemNeeded());
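
  // The two AllReduce ops carry the same fusion id, so they are expected to
  // appear as a single fused communication kernel.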
  size_t comm_kernel_num = 0;
  for (size_t i = 0; i < kernel_infos.size(); ++i) {
    auto kernel_info = analyzer->GetMemUsageKernelInfo(i);
    ASSERT_NE(nullptr, kernel_info);
    if (kernel_info->is_comm_) {
      ++comm_kernel_num;
    }
  }
  ASSERT_EQ(1, comm_kernel_num);
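
  // Track the maximum number of kernels using a single tensor, and count the
  // tensor infos that carry fused-tensor ids.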
  size_t tensor_max_used = 0;
  size_t fused_tensor_size = 0;
  for (size_t i = 0; i < tensor_infos.size(); ++i) {
    auto tensor_info = analyzer->GetMemUsageTensorInfo(i);
    ASSERT_NE(nullptr, tensor_info);
    auto used_size = tensor_info->used_by_kernels_.size();
    tensor_max_used = tensor_max_used < used_size ? used_size : tensor_max_used;
    if (tensor_info->fused_tensor_ids_.size() > 0) {
      ++fused_tensor_size;
    }
  }
  ASSERT_EQ(2, tensor_max_used);
  ASSERT_EQ(2, fused_tensor_size);
}
} // namespace mindspore::device
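
Beyond the assertions above, the new tensor-info fields can be inspected directly. Below is a minimal illustrative sketch, not part of this commit, that walks the analyzer results and prints which tensors were folded into a fused tensor. Only the MemUsageAnalyzer API and the field names visible in the diff are taken from the source; the helper DumpFusedTensors itself, and the assumption that fused_tensor_ids_ behaves like a standard container, are hypothetical.

#include <iostream>
#include <memory>

namespace mindspore::device {
// Hypothetical helper: print every tensor that participates in a fused tensor.
void DumpFusedTensors(const std::shared_ptr<MemUsageAnalyzer> &analyzer) {
  auto tensor_infos = analyzer->GetMemUsageTensorInfos();
  for (size_t i = 0; i < tensor_infos.size(); ++i) {
    auto info = analyzer->GetMemUsageTensorInfo(i);
    if (info == nullptr || info->fused_tensor_ids_.empty()) {
      continue;
    }
    std::cout << "tensor " << i << " fused with ids:";
    for (const auto &id : info->fused_tensor_ids_) {
      std::cout << " " << id;
    }
    std::cout << "\n";
  }
}
}  // namespace mindspore::device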