forked from mindspore-Ecosystem/mindspore
!48312 Add fused_tensor case to the MemUsageAnalyzer
Merge pull request !48312 from yangluhang/mem_analyze_fused_tensor
This commit is contained in:
commit
8514a58d6f
|
@ -25,3 +25,16 @@ def add_net(x1, x2, x3, x4, x5):
|
|||
sum4 = add(sum3, x5)
|
||||
ret = mul(sum4, sum1)
|
||||
return ret
|
||||
|
||||
|
||||
# AllReduce primitive tagged with fusion attribute 1 so the backend groups
# its communication tensors into a single fused tensor — this is the case
# the MemUsageAnalyzer fused-tensor test below exercises.
all_reduce = P.AllReduce().add_prim_attr("fusion", 1)
# Element-wise multiply primitive used by the test nets.
mul = P.Mul()
def all_reduce_net(x1, x2, x3):
    """Build a net whose two AllReduce ops share fusion id 1.

    Two independent intermediate results each pass through the fused
    AllReduce primitive, and the reduced values are summed into the output.
    """
    left = mul(x1, x2)
    right = add(x2, x3)
    # Both intermediates go through the same fusion-tagged AllReduce.
    left_reduced = all_reduce(left)
    right_reduced = all_reduce(right)
    return add(left_reduced, right_reduced)
|
|
@ -59,4 +59,51 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
|
|||
|
||||
ASSERT_EQ(100, analyzer->LeastMemNeeded());
|
||||
}
|
||||
|
||||
/// Feature: MemUsageAnalyzer
|
||||
/// Description: Test MemUsageAnalyzer interface with allreduce node
|
||||
/// Expectation: Pass all interface test
|
||||
TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer_fused_tesnor) {
|
||||
auto net = get_py_func_("all_reduce_net");
|
||||
EXPECT_NE(net, nullptr);
|
||||
std::vector<int64_t> shp_x{2, 2, 2, 2};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
|
||||
|
||||
auto func_graph = GetFuncGraph(net, args_spec_list);
|
||||
auto kernel_graph = Compile(func_graph);
|
||||
|
||||
auto analyzer = std::make_shared<MemUsageAnalyzer>();
|
||||
analyzer->Analyze(kernel_graph);
|
||||
auto kernel_infos = analyzer->GetMemUsageKernelInfos();
|
||||
auto tensor_infos = analyzer->GetMemUsageTensorInfos();
|
||||
|
||||
ASSERT_EQ(4, kernel_infos.size());
|
||||
ASSERT_EQ(14, tensor_infos.size());
|
||||
ASSERT_EQ(260, analyzer->LeastMemNeeded());
|
||||
|
||||
size_t comm_kernel_num = 0;
|
||||
for (size_t i = 0; i < kernel_infos.size(); ++i) {
|
||||
auto kernel_info = analyzer->GetMemUsageKernelInfo(i);
|
||||
ASSERT_NE(nullptr, kernel_info);
|
||||
if (kernel_info->is_comm_) {
|
||||
++comm_kernel_num;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(1, comm_kernel_num);
|
||||
|
||||
size_t tensor_max_used = 0;
|
||||
size_t fused_tensor_size = 0;
|
||||
for (size_t i = 0; i < tensor_infos.size(); ++i) {
|
||||
auto tensor_info = analyzer->GetMemUsageTensorInfo(i);
|
||||
ASSERT_NE(nullptr, tensor_info);
|
||||
auto used_size = tensor_info->used_by_kernels_.size();
|
||||
tensor_max_used = tensor_max_used < used_size ? used_size : tensor_max_used;
|
||||
if (tensor_info->fused_tensor_ids_.size() > 0) {
|
||||
++fused_tensor_size;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(2, tensor_max_used);
|
||||
ASSERT_EQ(2, fused_tensor_size);
|
||||
}
|
||||
} // namespace mindspore::device
|
Loading…
Reference in New Issue