!26127 Add cell dump option when dump_mode=2

Merge pull request !26127 from sabrinasun_59ee/cell
This commit is contained in:
i-robot 2021-11-17 01:58:46 +00:00 committed by Gitee
commit 605c07d898
4 changed files with 224 additions and 6 deletions

View File

@ -315,10 +315,17 @@ void CheckJsonArrayType(const nlohmann::json &content, const std::string &key) {
}
// Parses and validates the dump_mode setting of the dump configuration.
//
// `content` is the JSON value of the kDumpMode key; it is checked with
// CheckJsonUnsignedType before being stored in dump_mode_.
// Accepted values are DUMP_ALL (0), DUMP_KERNEL (1) and DUMP_CELL (2); any
// other value is reported through MS_LOG(EXCEPTION).
// DUMP_CELL is additionally rejected when the device target is not Ascend or
// when e2e dump is enabled.
// NOTE(review): the DUMP_CELL check reads e2e_dump_enabled_, so that flag is
// presumably parsed before this function runs — confirm against the caller.
void DumpJsonParser::ParseDumpMode(const nlohmann::json &content) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  CheckJsonUnsignedType(content, kDumpMode);
  dump_mode_ = content;
  if (dump_mode_ < DUMP_ALL || dump_mode_ > DUMP_CELL) {
    MS_LOG(EXCEPTION) << "Dump config parse failed, dump_mode should be 0, 1 or 2, but got " << dump_mode_;
  }
  if (dump_mode_ == DUMP_CELL) {
    if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice || e2e_dump_enabled_) {
      MS_LOG(EXCEPTION) << "Cell dump is only supported in Ascend async dump. Please set dump_mode to 0 or 1.";
    }
  }
}
@ -546,11 +553,23 @@ void DumpJsonParser::JudgeDumpEnabled() {
}
// Reports whether the kernel named `op_full_name` has to be dumped under the
// configured dump_mode_:
//   DUMP_ALL    - every kernel is dumped.
//   DUMP_KERNEL - only kernels that have an entry in kernels_.
//   DUMP_CELL   - only kernels recorded in cell_dump_kernels_ (filled by
//                 GetCellDumpFlag).
// Any other dump_mode_ value yields false.
bool DumpJsonParser::NeedDump(const std::string &op_full_name) const {
  bool need_dump = false;
  switch (dump_mode_) {
    case DUMP_ALL:
      need_dump = true;
      break;
    case DUMP_KERNEL:
      need_dump = kernels_.find(op_full_name) != kernels_.end();
      break;
    case DUMP_CELL:
      need_dump =
        std::find(cell_dump_kernels_.begin(), cell_dump_kernels_.end(), op_full_name) != cell_dump_kernels_.end();
      break;
    default:
      // ParseDumpMode rejects values outside the enum; nothing is dumped for an unknown mode.
      break;
  }
  return need_dump;
}
void DumpJsonParser::MatchKernel(const std::string &kernel_name) {
@ -610,10 +629,29 @@ bool DumpJsonParser::OutputNeedDump() const {
return input_output_ == kDumpInputAndOutput || input_output_ == kDumpOutputOnly;
}
// Collects the kernels of `kernel_graph` that carry a dump flag.
//
// Does nothing unless dump_mode_ is DUMP_CELL. Otherwise walks the graph's
// execution order and appends the node name of every kernel for which
// AnfAlgo::GetDumpFlag returns true to cell_dump_kernels_. The list is only
// appended to here, never cleared.
void DumpJsonParser::GetCellDumpFlag(const session::KernelGraph &kernel_graph) {
  if (dump_mode_ != DUMP_CELL) {
    return;
  }
  for (const auto &kernel : kernel_graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (!AnfAlgo::GetDumpFlag(kernel)) {
      continue;
    }
    // Resolve the name once; it is used for both the log line and the stored entry.
    const auto kernel_name = GetKernelNodeName(kernel);
    MS_LOG(INFO) << "Dump flag is true for " << kernel_name;
    cell_dump_kernels_.push_back(kernel_name);
  }
}
void DumpJsonParser::UpdateNeedDumpKernels(const session::KernelGraph &kernel_graph) {
if (!async_dump_enabled_) {
return;
}
MS_LOG(INFO) << "Get async kernel dump flag";
GetCellDumpFlag(kernel_graph);
MS_LOG(INFO) << "Update async dump kernel list for hccl";
std::map<std::string, uint32_t> update_kernels;
for (const auto &kernel : kernel_graph.execution_order()) {

View File

@ -63,11 +63,13 @@ class DumpJsonParser {
// Whether kernel inputs should be dumped.
// NOTE(review): presumably derived from input_output_ like OutputNeedDump — confirm in the .cc file.
bool InputNeedDump() const;
// True when input_output_ is kDumpInputAndOutput or kDumpOutputOnly.
bool OutputNeedDump() const;
// Returns a path string for the given graph id.
// NOTE(review): judging by the name this is the op-overflow bin directory — confirm in the .cc file.
std::string GetOpOverflowBinPath(uint32_t graph_id) const;
// In DUMP_CELL mode, records the names of the kernels in kernel_graph whose dump flag is set.
void GetCellDumpFlag(const session::KernelGraph &kernel_graph);
// No-op unless async dump is enabled; calls GetCellDumpFlag and then updates the async dump kernel list.
void UpdateNeedDumpKernels(const session::KernelGraph &kernel_graph);
// Drops every graph pointer stored by SaveGraph.
void ClearGraph() { graphs_.clear(); }
// Appends a graph pointer to graphs_.
void SaveGraph(session::KernelGraph *graph) { (void)graphs_.emplace_back(graph); }
// Read-only access to the saved graph pointers.
const std::vector<session::KernelGraph *> &graphs() const { return graphs_; }
// Values accepted for dump_mode: dump every kernel (0), dump only the kernels listed in
// kernels_ (1), dump only the kernels collected by GetCellDumpFlag (2).
enum JsonDumpMode { DUMP_ALL = 0, DUMP_KERNEL = 1, DUMP_CELL = 2 };
private:
DumpJsonParser() = default;
@ -84,6 +86,7 @@ class DumpJsonParser {
// NOTE(review): presumably the "iteration" setting of the dump config — confirm.
std::string iteration_;
// Input/output selector; OutputNeedDump compares it against kDumpInputAndOutput and kDumpOutputOnly.
uint32_t input_output_{0};
// Keyed by kernel name; NeedDump checks membership when dump_mode_ is DUMP_KERNEL.
// NOTE(review): the meaning of the uint32_t value is not visible here — confirm.
std::map<std::string, uint32_t> kernels_;
// Names of kernels whose dump flag is set; appended by GetCellDumpFlag and searched by
// NeedDump when dump_mode_ is DUMP_CELL.
std::vector<std::string> cell_dump_kernels_;
std::set<uint32_t> support_devices_;
uint32_t op_debug_mode_{0};
bool trans_flag_{false};

View File

@ -154,6 +154,21 @@ def generate_statistic_dump_json(dump_path, json_file_name, test_key, saved_data
with open(json_file_name, 'w') as f:
json.dump(data, f)
def generate_cell_dump_json(dump_path, json_file_name, test_key, dump_mode):
    """
    Util function to generate the dump configuration json file for cell dump tests.

    Args:
        dump_path (str): Value written to common_dump_settings.path.
        json_file_name (str): Path of the json file to write.
        test_key (str): Only "test_async_dump" is accepted.
        dump_mode (int): Value written to common_dump_settings.dump_mode.

    Raises:
        ValueError: If test_key is not "test_async_dump".
    """
    import copy

    if test_key != "test_async_dump":
        raise ValueError(
            "Failed to generate dump json file. Cell dump only support in async dump")
    # Work on a private copy so the path and dump_mode set here do not leak
    # into other users of the shared async_dump_dict template.
    data = copy.deepcopy(async_dump_dict)
    data["common_dump_settings"]["path"] = dump_path
    data["common_dump_settings"]["dump_mode"] = dump_mode
    with open(json_file_name, 'w') as f:
        json.dump(data, f)
def check_dump_structure(dump_path, json_file_path, num_card, num_graph, num_iteration):
"""
Util to check if the dump structure is correct.

View File

@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import sys
import tempfile
import time
import shutil
import glob
import numpy as np
import pytest
from mindspore import Tensor, set_dump
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.nn import Dense
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell
from mindspore.nn import WithLossCell
from dump_test_utils import generate_cell_dump_json, check_dump_structure
from tests.security_utils import security_off_wrap
class ReluReduceMeanDenseRelu(Cell):
    """Network applying ReLU, ReduceMean over axes (2, 3), Dense and ReLU in that order."""

    def __init__(self, kernel, bias, in_channel, num_class):
        super().__init__()
        self.relu = P.ReLU()
        self.mean = P.ReduceMean(keep_dims=False)
        self.dense = Dense(in_channel, num_class, kernel, bias)

    def construct(self, x_):
        activated = self.relu(x_)
        reduced = self.mean(activated, (2, 3))
        projected = self.dense(reduced)
        return self.relu(projected)
def run_multi_layer_train(is_set_dump):
    """
    Run a single training step of ReluReduceMeanDenseRelu wrapped in
    WithLossCell and TrainOneStepCell.

    Args:
        is_set_dump (bool): When true, set_dump is called on the network's relu
            operator before the training cells are built.
    """
    dense_weight = Tensor(np.ones((1000, 2048)).astype(np.float32))
    dense_bias = Tensor(np.ones((1000,)).astype(np.float32))
    backbone = ReluReduceMeanDenseRelu(dense_weight, dense_bias, 2048, 1000)
    if is_set_dump:
        set_dump(backbone.relu)

    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
    trainable_params = filter(lambda param: param.requires_grad, backbone.get_parameters())
    momentum_opt = Momentum(learning_rate=0.1, momentum=0.1, params=trainable_params)
    step_cell = TrainOneStepCell(WithLossCell(backbone, loss_fn), momentum_opt)
    step_cell.set_train()

    batch = Tensor(np.random.randn(32, 2048, 7, 7).astype(np.float32))
    targets = Tensor(np.zeros(shape=(32, 1000)).astype(np.float32))
    step_cell(batch, targets)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_cell_dump():
    """
    Feature: Cell Dump
    Description: Test cell dump
    Expectation: Only dump cell set by set_dump when dump_mode = 2
    """
    if sys.platform != 'linux':
        return
    with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
        dump_path = os.path.join(tmp_dir, 'cell_dump')
        dump_config_path = os.path.join(tmp_dir, 'cell_dump.json')
        generate_cell_dump_json(dump_path, dump_config_path, 'test_async_dump', 2)
        os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
        try:
            if os.path.isdir(dump_path):
                shutil.rmtree(dump_path)
            run_multi_layer_train(True)
            dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
            # Poll up to 5 times, 2 seconds apart, for the dump directory to appear.
            for _ in range(5):
                if os.path.exists(dump_file_path):
                    break
                time.sleep(2)
            check_dump_structure(dump_path, dump_config_path, 1, 1, 1)

            # make sure 2 relu dump files are generated with correct name prefix
            assert len(os.listdir(dump_file_path)) == 2
            relu_file_name = "ReLU.Default_network-WithLossCell__backbone-ReluReduceMeanDenseRelu_ReLU-op*.*.*.*"
            relu_files = glob.glob(os.path.join(dump_file_path, relu_file_name))
            assert len(relu_files) == 2
        finally:
            # Always drop the dump config so a failing assertion cannot leak it into later tests.
            os.environ.pop('MINDSPORE_DUMP_CONFIG', None)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_not_cell_dump():
    """
    Feature: Cell Dump
    Description: Test cell dump
    Expectation: Should ignore set_dump when dump_mode != 2
    """
    if sys.platform != 'linux':
        return
    with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
        dump_path = os.path.join(tmp_dir, 'cell_dump')
        dump_config_path = os.path.join(tmp_dir, 'cell_dump.json')
        generate_cell_dump_json(dump_path, dump_config_path, 'test_async_dump', 0)
        os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
        try:
            if os.path.isdir(dump_path):
                shutil.rmtree(dump_path)
            run_multi_layer_train(True)
            dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
            # Poll up to 5 times, 2 seconds apart, for the dump directory to appear.
            for _ in range(5):
                if os.path.exists(dump_file_path):
                    break
                time.sleep(2)
            check_dump_structure(dump_path, dump_config_path, 1, 1, 1)

            # make sure set_dump is ignored and all cell layer are dumped
            assert len(os.listdir(dump_file_path)) == 10
        finally:
            # Always drop the dump config so a failing assertion cannot leak it into later tests.
            os.environ.pop('MINDSPORE_DUMP_CONFIG', None)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_cell_empty_dump():
    """
    Feature: Cell Dump
    Description: Test cell dump
    Expectation: Should dump nothing when set_dump is not set and dump_mode = 2
    """
    if sys.platform != 'linux':
        return
    with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
        dump_path = os.path.join(tmp_dir, 'cell_dump')
        dump_config_path = os.path.join(tmp_dir, 'cell_dump.json')
        generate_cell_dump_json(dump_path, dump_config_path, 'test_async_dump', 2)
        os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
        try:
            if os.path.isdir(dump_path):
                shutil.rmtree(dump_path)
            run_multi_layer_train(False)
            dump_file_path = os.path.join(dump_path, 'rank_0', 'Net')
            # Give a dump that should not happen a fixed window in which to show up.
            time.sleep(5)

            # make sure nothing is dumped when no cell is marked by set_dump
            assert not os.path.exists(dump_file_path)
        finally:
            # Always drop the dump config so a failing assertion cannot leak it into later tests.
            os.environ.pop('MINDSPORE_DUMP_CONFIG', None)