forked from mindspore-Ecosystem/mindspore
add warnings for set_dump API, see I4AUIR for details
This commit is contained in:
parent
875f35d6d8
commit
c493f19b6b
|
@ -13,7 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Controlling dump behavior."""
|
||||
from warnings import warn
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore._c_expression import security
|
||||
|
||||
|
||||
|
@ -74,6 +76,25 @@ def set_dump(target, enabled=True):
|
|||
if not isinstance(enabled, bool):
|
||||
raise ValueError("The \"enabled\" parameter must be bool.")
|
||||
|
||||
# Checking for device target and mode.
|
||||
current_target = context.get_context("device_target")
|
||||
if current_target != "Ascend":
|
||||
# We will not return here in case user changed device_target later.
|
||||
warn("Current device_target is {}, which is not supported by set_dump. "
|
||||
"Only Ascend device target is supported currently. "
|
||||
"If you have Ascend device, consider set device_target to Ascend "
|
||||
"before calling set_dump.".format(current_target))
|
||||
|
||||
current_mode = context.get_context("mode")
|
||||
if current_mode != context.GRAPH_MODE:
|
||||
# We will not return here in case user changed mode later.
|
||||
warn(
|
||||
"Current mode is PYNATIVE_MODE, which is not supported by set_dump. "
|
||||
"Only GRAPH_MODE is supported currently. "
|
||||
"Consider set mode to GRAPH_MODE "
|
||||
"before calling set_dump.")
|
||||
|
||||
# The actual set dump logic.
|
||||
mode = "true" if enabled else "false"
|
||||
if isinstance(target, nn.Cell):
|
||||
primitives = getattr(target, "_primitives", {})
|
||||
|
|
|
@ -13,8 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test dump."""
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import set_dump
|
||||
|
@ -26,6 +29,7 @@ def test_set_dump_on_cell():
|
|||
Description: Use set_dump API on Cell instance.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
class MyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MyNet, self).__init__()
|
||||
|
@ -38,7 +42,9 @@ def test_set_dump_on_cell():
|
|||
return x
|
||||
|
||||
net = MyNet()
|
||||
set_dump(net.conv1)
|
||||
set_dump(net.relu1)
|
||||
|
||||
assert net.relu1.relu.attrs["dump"] == "true"
|
||||
|
||||
|
||||
def test_set_dump_on_primitive():
|
||||
|
@ -49,6 +55,7 @@ def test_set_dump_on_primitive():
|
|||
"""
|
||||
op = ops.Add()
|
||||
set_dump(op)
|
||||
assert op.attrs["dump"] == "true"
|
||||
|
||||
|
||||
def test_input_type_check():
|
||||
|
@ -59,3 +66,21 @@ def test_input_type_check():
|
|||
"""
|
||||
with pytest.raises(ValueError):
|
||||
set_dump(1)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Warning can only be triggered once, please execute "
|
||||
"this test case manually.")
|
||||
def test_set_dump_warning():
|
||||
"""
|
||||
Feature: Python API set_dump.
|
||||
Description: Test the warning about device target and mode.
|
||||
Expectation: Trigger warning message.
|
||||
"""
|
||||
context.set_context(device_target="CPU")
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
op = ops.Add()
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
set_dump(op)
|
||||
assert "Only Ascend device target is supported" in str(w[-2].message)
|
||||
assert "Only GRAPH_MODE is supported" in str(w[-1].message)
|
||||
|
|
Loading…
Reference in New Issue