!24165 offline debug api design

Merge pull request !24165 from wenkai/wk0825_2
This commit is contained in:
i-robot 2021-10-04 01:31:46 +00:00 committed by Gitee
commit bc37faad4d
5 changed files with 513 additions and 0 deletions

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DebuggerTensor."""
from abc import ABC
class DebuggerTensor(ABC):
"""
The tensor with specific rank, iteration and debugging info.
Note:
- Users should not instantiate this class manually.
- The instances of this class is immutable.
- A DebuggerTensor is always the output tensor of a node.
"""
@property
def node(self):
"""
Get the node that outputs this tensor.
Returns:
Node, the node that outputs this tensor.
"""
return None
@property
def name(self):
"""
Get the name of this tensor.
The name is composed of full name of a node and the slot number.
Returns:
str, the name of this tensor.
"""
return ""
@property
def slot(self):
"""
Get slot.
Returns:
int, the slot of the tensor on the node.
"""
return -1
@property
def iteration(self):
"""
Get the iteration for this tensor.
Returns:
int, the iteration for this tensor.
"""
return -1
@property
def rank(self):
"""
Get the rank for this tensor.
Returns:
int, the rank for this tensor.
"""
return -1
def get_value(self):
"""
Get the value of the tensor.
Returns:
numpy.ndarray, the value of the debugger tensor.
"""
def get_affected_nodes(self):
"""
Get the nodes that use current tensor as input.
"""

View File

@ -0,0 +1,138 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger python API."""
from typing import Iterable
from mindspore.offline_debug.debugger_tensor import DebuggerTensor
from mindspore.offline_debug.node import Node
from mindspore.offline_debug.watchpoints import WatchpointBase, WatchpointHit
class DumpAnalyzer:
"""
Analyzer to inspect the dump data.
Args:
summary_dir (str): The path of the summary directory which contains
dump folder.
mem_limit (int, optional): The memory limit for this debugger session in
MB. Default: None, which means no limit.
"""
def __init__(self, summary_dir, mem_limit=None):
self._summary_dir = summary_dir
self._mem_limit = mem_limit
def export_graphs(self, output_dir=None):
"""
Export the computational graph(s) in xlsx file(s) to the output_dir.
The file(s) will contain the stack info of graph nodes.
Args:
output_dir (str, optional): Output directory to save the file.
Default: None, which means to use the current working directory.
Returns:
str. The path of the generated file.
"""
def select_nodes(
self,
query_string,
use_regex=False,
match_target="name",
case_sensitive=True) -> Iterable[Node]:
"""
Select nodes.
Args:
query_string (str): Query string. For a node to be selected, the
match target field must contains or matches the query string.
use_regex (bool): Indicates whether query is a regex. Default: False.
match_target (str, optional): The field to search when selecting
nodes. Available values are "name", "stack".
"name" means to search the name of the nodes in the
graph. "stack" means the stack info of
the node. Default: "name".
case_sensitive (bool, optional): Whether case-sensitive when
selecting tensors. Default: True.
Returns:
Iterable[Node], the matched nodes.
"""
def select_tensors(
self,
query_string,
use_regex=False,
match_target="name",
iterations=None,
ranks=None,
slots=None,
case_sensitive=True) -> Iterable[DebuggerTensor]:
"""
Select tensors.
Args:
query_string (str): Query string. For a tensor to be selected, the
match target field must contains or matches the query string.
use_regex (bool): Indicates whether query is a regex. Default: False.
match_target (str, optional): The field to search when selecting
tensors. Available values are "name", "stack".
"name" means to search the name of the tensors in the
graph. "name" is composed of graph node's full_name
and the tensor's slot number. "stack" means the stack info of
the node that outputs this tensor. Default: "name".
iterations (list[int], optional): The iterations to select. Default:
None, which means all iterations will be selected.
ranks (list(int], optional): The ranks to select. Default: None,
which means all ranks will be selected.
slots (list[int], optional): The slot of the selected tensor.
Default: None, which means all slots will be selected.
case_sensitive (bool, optional): Whether case-sensitive when
selecting tensors. Default: True.
Returns:
Iterable[DebuggerTensor], the matched tensors.
"""
def get_iterations(self) -> Iterable[int]:
"""Get the available iterations this run."""
def get_ranks(self) -> Iterable[int]:
"""Get the available ranks in this run."""
def check_watchpoints(
self,
watchpoints: Iterable[WatchpointBase]) -> Iterable[WatchpointHit]:
"""
Check the given watch points on specified nodes(if available) on the
given iterations(if available) in a batch.
Note:
For speed, all watchpoints for the iteration should be given at
the same time to avoid reading tensors len(watchpoints) times.
Args:
watchpoints (Iterable[WatchpointBase]): The list of watchpoints.
Returns:
Iterable[WatchpointHit], the watchpoint hist list is carefully
sorted so that the user can see the most import hit on the
top of the list. When there are many many watchpoint hits,
we will display the list in a designed clear way.
"""

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Node in the computational graph."""
from abc import ABC
class Node(ABC):
"""Node in the computational graph."""
@property
def name(self):
"""
Get the full name of this node.
Returns:
str, the full name of the node.
"""
return ""
@property
def stack(self):
"""Get stack info."""
return None
def get_input_tensors(
self,
iterations=None,
ranks=None,
slots=None):
"""
Get the input tensors of the node.
Returns:
Iterable[DebuggerTensor], the input tensors of the node.
"""
def get_output_tensors(
self,
iterations=None,
ranks=None,
slots=None):
"""
Get the output tensors of this node.
Returns:
Iterable[DebuggerTensor], the output tensors of the node.
"""
def get_input_nodes(self):
"""
Get the input nodes of this node.
Returns:
Iterable[Node], the input nodes of this node.
"""
def get_output_nodes(self):
"""
Get the nodes that use the output tensors of this node.
Returns:
Iterable[Node], the output nodes of this node.
"""

View File

@ -0,0 +1,128 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Watchpoints."""
from mindspore.offline_debug.debugger_tensor import DebuggerTensor
class WatchpointBase:
"""
Base class for watchpoints.
Note:
- The watchpoint is bounded with tensor names.
- If multiple checking items is specified for one watch point instance,
a tensor needs to trigger all of them to trigger the watchpoint.
"""
@property
def name(self):
"""Get the name for the watchpoint."""
raise NotImplementedError
def check(self):
"""
Check the watchpoint against the tensors.
Returns:
list[WatchpointHit], the hits of the watchpoint.
"""
class WatchpointHit:
"""
Watchpoint hit.
Note:
- This class is not meant to be instantiated by user.
- The instances of this class is immutable.
Args:
tensor (DebuggerTensor): The tensor which hits the watchpoint.
watchpoint (WatchpointBase): The WatchPointBase object initialized with
user setting value.
watchpoint_hit_detail (WatchpointBase): The WatchPointBase object
initialized with actual value of the Tensor.
error_code: The code describing error.
"""
def __init__(self,
tensor: DebuggerTensor,
watchpoint: WatchpointBase,
watchpoint_hit_detail: WatchpointBase,
error_code):
self._tensor = tensor
self._watchpoint = watchpoint
self._error_code = error_code
self._watchpoint_hit_detail = watchpoint_hit_detail
def __str__(self):
if self._error_code:
return f"Watchpoint {self._watchpoint.name} check failed " \
f"on tensor {self._tensor.name}. " \
f"Error detail: error detail."
return f"Watchpoint {self._watchpoint.name} triggered on " \
f"tensor {self._tensor.name}. " \
f"The setting for watchpoint is mean_gt=0.2, abs_mean_gt=0.3." \
f"The actual value of the tensor is " \
f"mean_gt=0.21, abs_mean_gt=0.35."
@property
def tensor(self) -> DebuggerTensor:
"""Get the tensor for this watchpoint hit."""
return self._tensor
def get_watchpoint(self):
"""Get the original watchpoint."""
return self._watchpoint
def get_hit_detail(self):
"""Get the actual values for the thresholds in the watchpoint."""
return self._watchpoint_hit_detail
class TensorTooLargeWatchpoint(WatchpointBase):
"""
Tensor too large watchpoint.
When all specified checking conditions were satisfied, this watchpoint would
be hit after a check.
Args:
tensors (Iterable[DebuggerTensor]): The tensors to check.
abs_mean_gt (float, optional): The threshold for mean of the absolute
value of the tensor. When the actual value was greater than this
threshold, this checking condition would be satisfied.
max_gt (float, optional): The threshold for maximum of the tensor. When
the actual value was greater than this threshold, this checking
condition would be satisfied.
min_gt (float, optional): The threshold for minimum of the tensor. When
the actual value was greater than this threshold, this checking
condition would be satisfied.
mean_gt (float, optional): The threshold for mean of the tensor. When
the actual value was greater than this threshold, this checking
condition would be satisfied.
"""
def __init__(self, tensors,
abs_mean_gt=None, max_gt=None, min_gt=None, mean_gt=None):
self._tensors = tensors
self._abs_mean_gt = abs_mean_gt
self._max_gt = max_gt
self._min_gt = min_gt
self._mean_gt = mean_gt
@property
def name(self):
return "TensorTooLarge"

View File

@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test debug API."""
import pytest
from mindspore.offline_debug.dump_analyzer import DumpAnalyzer
from mindspore.offline_debug.watchpoints import TensorTooLargeWatchpoint
@pytest.mark.skip(reason="Feature under development.")
def test_export_graphs():
"""Test debug API."""
my_run = DumpAnalyzer(
summary_dir="/path/to/summary-dir1"
)
# Export the info about computational graph. Should support multi graphs.
my_run.export_graphs()
@pytest.mark.skip(reason="Feature under development.")
def test_select_tensors():
"""Test debug API."""
my_run = DumpAnalyzer(
summary_dir="/path/to/summary-dir2"
)
# Find the interested tensors.
matched_tensors = my_run.select_tensors(".*conv1.*", use_regex=True)
assert matched_tensors == []
@pytest.mark.skip(reason="Feature under development.")
def test_check_watchpoints_all_iterations():
"""Test debug API."""
my_run = DumpAnalyzer(
summary_dir="/path/to/summary-dir3"
)
# Checking all the iterations.
watchpoints = [
TensorTooLargeWatchpoint(
tensors=my_run.select_tensors(
"(*.weight^)|(*.bias^)", use_regex=True),
abs_mean_gt=0.1)
]
watch_point_hits = my_run.check_watchpoints(watchpoints=watchpoints)
assert watch_point_hits == []
@pytest.mark.skip(reason="Feature under development.")
def test_check_watchpoints_one_iteration():
"""Test debug API."""
my_run = DumpAnalyzer(
summary_dir="/path/to/summary-dir4"
)
# Checking specific iteration.
watchpoints = [
TensorTooLargeWatchpoint(
tensors=my_run.select_tensors(
"(*.weight^)|(*.bias^)", use_regex=True,
iterations=[1]),
abs_mean_gt=0.1)
]
watch_point_hits = my_run.check_watchpoints(watchpoints=watchpoints)
assert watch_point_hits == []