forked from mindspore-Ecosystem/mindspore
!24165 offline debug api design
Merge pull request !24165 from wenkai/wk0825_2
This commit is contained in:
commit
bc37faad4d
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""DebuggerTensor."""
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class DebuggerTensor(ABC):
    """
    A tensor tied to a specific rank and iteration, plus its debugging info.

    Note:
        - Users should not instantiate this class manually.
        - Instances of this class are immutable.
        - A DebuggerTensor is always the output tensor of a node.
    """

    @property
    def node(self):
        """
        Get the node that outputs this tensor.

        Returns:
            Node, the node that outputs this tensor. This base class
            returns None.
        """
        return None

    @property
    def name(self):
        """
        Get the name of this tensor.

        The name is composed of the full name of a node and the slot number.

        Returns:
            str, the name of this tensor. This base class returns an empty
            string.
        """
        return ""

    @property
    def slot(self):
        """
        Get the slot of this tensor.

        Returns:
            int, the slot of the tensor on the node. This base class
            returns -1.
        """
        return -1

    @property
    def iteration(self):
        """
        Get the iteration of this tensor.

        Returns:
            int, the iteration for this tensor. This base class returns -1.
        """
        return -1

    @property
    def rank(self):
        """
        Get the rank of this tensor.

        Returns:
            int, the rank for this tensor. This base class returns -1.
        """
        return -1

    def get_value(self):
        """
        Get the value of the tensor.

        Returns:
            numpy.ndarray, the value of the debugger tensor. This base class
            has no implementation and returns None.
        """
        return None

    def get_affected_nodes(self):
        """
        Get the nodes that use the current tensor as input.

        This base class has no implementation and returns None.
        """
        return None
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Debugger python API."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from mindspore.offline_debug.debugger_tensor import DebuggerTensor
|
||||
from mindspore.offline_debug.node import Node
|
||||
from mindspore.offline_debug.watchpoints import WatchpointBase, WatchpointHit
|
||||
|
||||
|
||||
class DumpAnalyzer:
    """
    Analyzer to inspect the dump data.

    Args:
        summary_dir (str): The path of the summary directory which contains
            the dump folder.
        mem_limit (int, optional): The memory limit for this debugger
            session in MB. Default: None, which means no limit.
    """

    def __init__(self, summary_dir, mem_limit=None):
        # Both arguments are only stored here; nothing is read from disk.
        self._mem_limit = mem_limit
        self._summary_dir = summary_dir

    def export_graphs(self, output_dir=None):
        """
        Export the computational graph(s) as xlsx file(s) into output_dir.

        The file(s) will contain the stack info of the graph nodes.

        Note:
            This method has no implementation yet and returns None.

        Args:
            output_dir (str, optional): Output directory to save the file.
                Default: None, which means to use the current working
                directory.

        Returns:
            str, the path of the generated file.
        """

    def select_nodes(self,
                     query_string,
                     use_regex=False,
                     match_target="name",
                     case_sensitive=True) -> Iterable[Node]:
        """
        Select nodes.

        Note:
            This method has no implementation yet and returns None.

        Args:
            query_string (str): Query string. For a node to be selected, the
                match target field must contain or match the query string.
            use_regex (bool): Indicates whether the query is a regex.
                Default: False.
            match_target (str, optional): The field to search when selecting
                nodes. Available values are "name" and "stack". "name" means
                to search the names of the nodes in the graph. "stack" means
                the stack info of the node. Default: "name".
            case_sensitive (bool, optional): Whether the match is
                case-sensitive when selecting nodes. Default: True.

        Returns:
            Iterable[Node], the matched nodes.
        """

    def select_tensors(self,
                       query_string,
                       use_regex=False,
                       match_target="name",
                       iterations=None,
                       ranks=None,
                       slots=None,
                       case_sensitive=True) -> Iterable[DebuggerTensor]:
        """
        Select tensors.

        Note:
            This method has no implementation yet and returns None.

        Args:
            query_string (str): Query string. For a tensor to be selected,
                the match target field must contain or match the query
                string.
            use_regex (bool): Indicates whether the query is a regex.
                Default: False.
            match_target (str, optional): The field to search when selecting
                tensors. Available values are "name" and "stack". "name"
                means to search the names of the tensors in the graph; a
                tensor name is composed of the graph node's full_name and
                the tensor's slot number. "stack" means the stack info of
                the node that outputs this tensor. Default: "name".
            iterations (list[int], optional): The iterations to select.
                Default: None, which means all iterations will be selected.
            ranks (list[int], optional): The ranks to select. Default: None,
                which means all ranks will be selected.
            slots (list[int], optional): The slots of the selected tensors.
                Default: None, which means all slots will be selected.
            case_sensitive (bool, optional): Whether the match is
                case-sensitive when selecting tensors. Default: True.

        Returns:
            Iterable[DebuggerTensor], the matched tensors.
        """

    def get_iterations(self) -> Iterable[int]:
        """Get the available iterations in this run (not implemented yet)."""

    def get_ranks(self) -> Iterable[int]:
        """Get the available ranks in this run (not implemented yet)."""

    def check_watchpoints(
            self,
            watchpoints: Iterable[WatchpointBase]) -> Iterable[WatchpointHit]:
        """
        Check the given watchpoints in a batch.

        The watchpoints are checked on the specified nodes (if available) and
        on the given iterations (if available).

        Note:
            - For speed, all watchpoints for an iteration should be given at
              the same time, to avoid reading tensors len(watchpoints) times.
            - This method has no implementation yet and returns None.

        Args:
            watchpoints (Iterable[WatchpointBase]): The list of watchpoints.

        Returns:
            Iterable[WatchpointHit], the watchpoint hit list. The list is
            meant to be sorted so that the user sees the most important hit
            at the top; when there are many watchpoint hits, the list is
            meant to be displayed in a clear, designed way.
        """
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Node in the computational graph."""
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class Node(ABC):
    """A node in the computational graph."""

    @property
    def name(self):
        """
        Get the full name of this node.

        Returns:
            str, the full name of the node. This base class returns an empty
            string.
        """
        return ""

    @property
    def stack(self):
        """Get the stack info of this node; this base class returns None."""
        return None

    def get_input_tensors(self, iterations=None, ranks=None, slots=None):
        """
        Get the input tensors of this node.

        The iterations, ranks and slots arguments are accepted but unused,
        because this base class has no implementation and returns None.

        Returns:
            Iterable[DebuggerTensor], the input tensors of the node.
        """
        return None

    def get_output_tensors(self, iterations=None, ranks=None, slots=None):
        """
        Get the output tensors of this node.

        The iterations, ranks and slots arguments are accepted but unused,
        because this base class has no implementation and returns None.

        Returns:
            Iterable[DebuggerTensor], the output tensors of the node.
        """
        return None

    def get_input_nodes(self):
        """
        Get the input nodes of this node.

        This base class has no implementation and returns None.

        Returns:
            Iterable[Node], the input nodes of this node.
        """
        return None

    def get_output_nodes(self):
        """
        Get the nodes that use the output tensors of this node.

        This base class has no implementation and returns None.

        Returns:
            Iterable[Node], the output nodes of this node.
        """
        return None
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Watchpoints."""
|
||||
from mindspore.offline_debug.debugger_tensor import DebuggerTensor
|
||||
|
||||
|
||||
class WatchpointBase:
    """
    Base class for watchpoints.

    Note:
        - A watchpoint is bound to tensor names.
        - If multiple checking items are specified for one watchpoint
          instance, a tensor needs to trigger all of them to trigger the
          watchpoint.
    """

    @property
    def name(self):
        """
        Get the name of the watchpoint.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def check(self):
        """
        Check the watchpoint against the tensors.

        This base class has no implementation and returns None.

        Returns:
            list[WatchpointHit], the hits of the watchpoint.
        """
        return None
|
||||
|
||||
|
||||
class WatchpointHit:
    """
    Watchpoint hit.

    Note:
        - This class is not meant to be instantiated by users.
        - Instances of this class are immutable.

    Args:
        tensor (DebuggerTensor): The tensor which hits the watchpoint.
        watchpoint (WatchpointBase): The WatchpointBase object initialized
            with the user setting values.
        watchpoint_hit_detail (WatchpointBase): The WatchpointBase object
            initialized with the actual values of the tensor.
        error_code: The code describing the error. A truthy value makes
            ``str()`` report a failed check instead of a triggered
            watchpoint.
    """

    def __init__(self,
                 tensor: DebuggerTensor,
                 watchpoint: WatchpointBase,
                 watchpoint_hit_detail: WatchpointBase,
                 error_code):
        self._tensor = tensor
        self._watchpoint = watchpoint
        self._error_code = error_code
        self._watchpoint_hit_detail = watchpoint_hit_detail

    def __str__(self):
        """
        Build a human-readable description of this hit.

        Returns:
            str, a "check failed" message when the error code is truthy,
            otherwise a "triggered" message.
        """
        # NOTE: the error detail and the threshold/actual values below are
        # hard-coded sample text; they are not derived from the stored
        # watchpoint, hit detail or error code.
        if self._error_code:
            return (f"Watchpoint {self._watchpoint.name} check failed "
                    f"on tensor {self._tensor.name}. "
                    "Error detail: error detail.")

        # Each fragment ends with a space so that the implicitly concatenated
        # sentences do not run together.
        return (f"Watchpoint {self._watchpoint.name} triggered on "
                f"tensor {self._tensor.name}. "
                "The setting for watchpoint is mean_gt=0.2, abs_mean_gt=0.3. "
                "The actual value of the tensor is "
                "mean_gt=0.21, abs_mean_gt=0.35.")

    @property
    def tensor(self) -> DebuggerTensor:
        """Get the tensor for this watchpoint hit."""
        return self._tensor

    def get_watchpoint(self):
        """Get the original watchpoint."""
        return self._watchpoint

    def get_hit_detail(self):
        """Get the actual values for the thresholds in the watchpoint."""
        return self._watchpoint_hit_detail
|
||||
|
||||
|
||||
class TensorTooLargeWatchpoint(WatchpointBase):
    """
    Tensor too large watchpoint.

    This watchpoint is hit after a check when all the specified checking
    conditions are satisfied.

    Args:
        tensors (Iterable[DebuggerTensor]): The tensors to check.
        abs_mean_gt (float, optional): The threshold for the mean of the
            absolute value of the tensor. The condition is satisfied when
            the actual value is greater than this threshold.
        max_gt (float, optional): The threshold for the maximum of the
            tensor. The condition is satisfied when the actual value is
            greater than this threshold.
        min_gt (float, optional): The threshold for the minimum of the
            tensor. The condition is satisfied when the actual value is
            greater than this threshold.
        mean_gt (float, optional): The threshold for the mean of the tensor.
            The condition is satisfied when the actual value is greater than
            this threshold.
    """

    def __init__(self, tensors,
                 abs_mean_gt=None, max_gt=None, min_gt=None, mean_gt=None):
        self._tensors = tensors
        # Thresholds are stored as given; None is the default for each one.
        self._abs_mean_gt, self._mean_gt = abs_mean_gt, mean_gt
        self._max_gt, self._min_gt = max_gt, min_gt

    @property
    def name(self):
        """Get the name of this watchpoint, which is "TensorTooLarge"."""
        return "TensorTooLarge"
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test debug API."""
|
||||
import pytest
|
||||
|
||||
from mindspore.offline_debug.dump_analyzer import DumpAnalyzer
|
||||
from mindspore.offline_debug.watchpoints import TensorTooLargeWatchpoint
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature under development.")
def test_export_graphs():
    """Call DumpAnalyzer.export_graphs with the default output directory."""
    analyzer = DumpAnalyzer(summary_dir="/path/to/summary-dir1")

    # Export the info about the computational graph; this should support
    # multiple graphs.
    analyzer.export_graphs()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature under development.")
def test_select_tensors():
    """Select tensors by a regex on their names and expect an empty result."""
    analyzer = DumpAnalyzer(summary_dir="/path/to/summary-dir2")

    # Look up the tensors of interest with a regex query.
    conv1_tensors = analyzer.select_tensors(".*conv1.*", use_regex=True)
    assert conv1_tensors == []
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature under development.")
def test_check_watchpoints_all_iterations():
    """Check a TensorTooLargeWatchpoint over all iterations."""
    my_run = DumpAnalyzer(
        summary_dir="/path/to/summary-dir3"
    )

    # Checking all the iterations.
    # The query is passed with use_regex=True, so it has to be a valid
    # regular expression: ".*" matches any prefix and "$" anchors the end.
    # NOTE(review): assumes the names of interest end with ".weight" or
    # ".bias" -- confirm against the real tensor naming scheme.
    watchpoints = [
        TensorTooLargeWatchpoint(
            tensors=my_run.select_tensors(
                r"(.*\.weight$)|(.*\.bias$)", use_regex=True),
            abs_mean_gt=0.1)
    ]

    watch_point_hits = my_run.check_watchpoints(watchpoints=watchpoints)
    assert watch_point_hits == []
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature under development.")
def test_check_watchpoints_one_iteration():
    """Check a TensorTooLargeWatchpoint on a single iteration."""
    my_run = DumpAnalyzer(
        summary_dir="/path/to/summary-dir4"
    )

    # Checking a specific iteration.
    # The query is passed with use_regex=True, so it has to be a valid
    # regular expression: ".*" matches any prefix and "$" anchors the end.
    # NOTE(review): assumes the names of interest end with ".weight" or
    # ".bias" -- confirm against the real tensor naming scheme.
    watchpoints = [
        TensorTooLargeWatchpoint(
            tensors=my_run.select_tensors(
                r"(.*\.weight$)|(.*\.bias$)", use_regex=True,
                iterations=[1]),
            abs_mean_gt=0.1)
    ]

    watch_point_hits = my_run.check_watchpoints(watchpoints=watchpoints)
    assert watch_point_hits == []
|
Loading…
Reference in New Issue