add warnings for set_dump API, see I4AUIR for details

This commit is contained in:
wenkai 2021-11-20 19:16:29 +08:00
parent 875f35d6d8
commit c493f19b6b
2 changed files with 47 additions and 1 deletions

View File

@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Controlling dump behavior."""
from warnings import warn
import mindspore.context as context
from mindspore._c_expression import security
@ -74,6 +76,25 @@ def set_dump(target, enabled=True):
if not isinstance(enabled, bool):
raise ValueError("The \"enabled\" parameter must be bool.")
# Checking for device target and mode.
current_target = context.get_context("device_target")
if current_target != "Ascend":
# We will not return here in case user changed device_target later.
warn("Current device_target is {}, which is not supported by set_dump. "
"Only Ascend device target is supported currently. "
"If you have Ascend device, consider set device_target to Ascend "
"before calling set_dump.".format(current_target))
current_mode = context.get_context("mode")
if current_mode != context.GRAPH_MODE:
# We will not return here in case user changed mode later.
warn(
"Current mode is PYNATIVE_MODE, which is not supported by set_dump. "
"Only GRAPH_MODE is supported currently. "
"Consider set mode to GRAPH_MODE "
"before calling set_dump.")
# The actual set dump logic.
mode = "true" if enabled else "false"
if isinstance(target, nn.Cell):
primitives = getattr(target, "_primitives", {})

View File

@ -13,8 +13,11 @@
# limitations under the License.
# ============================================================================
"""Test dump."""
import warnings
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import set_dump
@ -26,6 +29,7 @@ def test_set_dump_on_cell():
Description: Use set_dump API on Cell instance.
Expectation: Success.
"""
class MyNet(nn.Cell):
def __init__(self):
super(MyNet, self).__init__()
@ -38,7 +42,9 @@ def test_set_dump_on_cell():
return x
net = MyNet()
set_dump(net.conv1)
set_dump(net.relu1)
assert net.relu1.relu.attrs["dump"] == "true"
def test_set_dump_on_primitive():
@ -49,6 +55,7 @@ def test_set_dump_on_primitive():
"""
op = ops.Add()
set_dump(op)
assert op.attrs["dump"] == "true"
def test_input_type_check():
@ -59,3 +66,21 @@ def test_input_type_check():
"""
with pytest.raises(ValueError):
set_dump(1)
@pytest.mark.skip(reason="Warning can only be triggered once, please execute "
"this test case manually.")
def test_set_dump_warning():
"""
Feature: Python API set_dump.
Description: Test the warning about device target and mode.
Expectation: Trigger warning message.
"""
context.set_context(device_target="CPU")
context.set_context(mode=context.PYNATIVE_MODE)
op = ops.Add()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
set_dump(op)
assert "Only Ascend device target is supported" in str(w[-2].message)
assert "Only GRAPH_MODE is supported" in str(w[-1].message)