forked from mindspore-Ecosystem/mindspore
!48312 Add fused_tensor case to the MemUsageAnalyzer
Merge pull request !48312 from yangluhang/mem_analyze_fused_tensor
This commit is contained in: 8514a58d6f
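Summary: adds a Python all_reduce_net fixture whose two AllReduce calls share fusion id 1, and a C++ unit test that checks the kernel and tensor bookkeeping MemUsageAnalyzer produces for the resulting fused communication op.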
@@ -25,3 +25,16 @@ def add_net(x1, x2, x3, x4, x5):
     sum4 = add(sum3, x5)
     ret = mul(sum4, sum1)
     return ret
+
+
+all_reduce = P.AllReduce().add_prim_attr("fusion", 1)
+mul = P.Mul()
+
+
+def all_reduce_net(x1, x2, x3):
+    product = mul(x1, x2)
+    sum1 = add(x2, x3)
+    reduce1 = all_reduce(product)
+    reduce2 = all_reduce(sum1)
+    res = add(reduce1, reduce2)
+    return res
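Both all_reduce calls above go through the same primitive instance, so both carry fusion id 1 and are candidates to be combined into a single fused AllReduce by the communication-fusion pass. This is consistent with the assertions in the C++ test below: four kernel infos (Mul, Add, one fused AllReduce, Add), exactly one of which is a communication kernel, and tensors that take part in the fused op reporting non-empty fused_tensor_ids_.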
@@ -59,4 +59,51 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
   ASSERT_EQ(100, analyzer->LeastMemNeeded());
 }
+
+/// Feature: MemUsageAnalyzer
+/// Description: Test MemUsageAnalyzer interface with allreduce node
+/// Expectation: Pass all interface tests
+TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer_fused_tensor) {
+  auto net = get_py_func_("all_reduce_net");
+  EXPECT_NE(net, nullptr);
+  std::vector<int64_t> shp_x{2, 2, 2, 2};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
+
+  auto func_graph = GetFuncGraph(net, args_spec_list);
+  auto kernel_graph = Compile(func_graph);
+
+  auto analyzer = std::make_shared<MemUsageAnalyzer>();
+  analyzer->Analyze(kernel_graph);
+  auto kernel_infos = analyzer->GetMemUsageKernelInfos();
+  auto tensor_infos = analyzer->GetMemUsageTensorInfos();
+
+  ASSERT_EQ(4, kernel_infos.size());
+  ASSERT_EQ(14, tensor_infos.size());
+  ASSERT_EQ(260, analyzer->LeastMemNeeded());
+
+  size_t comm_kernel_num = 0;
+  for (size_t i = 0; i < kernel_infos.size(); ++i) {
+    auto kernel_info = analyzer->GetMemUsageKernelInfo(i);
+    ASSERT_NE(nullptr, kernel_info);
+    if (kernel_info->is_comm_) {
+      ++comm_kernel_num;
+    }
+  }
+  ASSERT_EQ(1, comm_kernel_num);
+
+  size_t tensor_max_used = 0;
+  size_t fused_tensor_size = 0;
+  for (size_t i = 0; i < tensor_infos.size(); ++i) {
+    auto tensor_info = analyzer->GetMemUsageTensorInfo(i);
+    ASSERT_NE(nullptr, tensor_info);
+    auto used_size = tensor_info->used_by_kernels_.size();
+    tensor_max_used = tensor_max_used < used_size ? used_size : tensor_max_used;
+    if (tensor_info->fused_tensor_ids_.size() > 0) {
+      ++fused_tensor_size;
+    }
+  }
+  ASSERT_EQ(2, tensor_max_used);
+  ASSERT_EQ(2, fused_tensor_size);
+}
 
 }  // namespace mindspore::device
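For reference when reading the assertions: each input is a 2 x 2 x 2 x 2 float32 tensor, i.e. 16 elements x 4 bytes = 64 bytes. The sketch below is not part of the commit; it shows one plausible way to walk the fused-tensor bookkeeping this test exercises, reusing analyzer and tensor_infos from the test body and assuming (as the indexed GetMemUsageTensorInfo(i) accessor suggests) that each id in fused_tensor_ids_ indexes back into the same tensor-info table.

  // Hedged sketch: enumerate tensors that belong to a fused communication
  // buffer. Assumes fused_tensor_ids_ holds indices into the analyzer's
  // tensor-info table; adjust if the real semantics differ.
  size_t fused_members = 0;
  for (size_t i = 0; i < tensor_infos.size(); ++i) {
    auto tensor_info = analyzer->GetMemUsageTensorInfo(i);
    if (tensor_info == nullptr || tensor_info->fused_tensor_ids_.empty()) {
      continue;
    }
    for (auto fused_id : tensor_info->fused_tensor_ids_) {
      auto peer = analyzer->GetMemUsageTensorInfo(fused_id);
      if (peer != nullptr) {
        ++fused_members;  // one member of the same fused buffer
      }
    }
  }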