add the set_dump API, see I4AUIR for details

This commit is contained in:
wenkai 2021-10-27 18:16:26 +08:00
parent c063fe67b9
commit d2eca80d40
5 changed files with 202 additions and 5 deletions

View File

@ -12,6 +12,7 @@
"mindspore/mindspore/_check_version.py" "unused-import"
"mindspore/mindspore/_check_version.py" "broad-except"
"mindspore/mindspore/common/parameter.py" "protected-access"
"mindspore/mindspore/common/dtype.py" "undefined-all-variable"
"mindspore/mindspore/context.py" "protected-access"
"mindspore/mindspore/ops/operations" "super-init-not-called"
"mindspore/mindspore/ops/operations/_quant_ops.py" "unused-import"

View File

@ -15,17 +15,48 @@
"""Top-level reference to dtype of common module."""
from . import dtype
from .api import ms_function
from .dtype import *
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
uint, number, tensor, string, type_none, tensor_type, Int, \
complex64, complex128, dtype_to_nptype, issubclass_, \
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .tensor import Tensor, RowTensor, SparseTensor
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor
# symbols from dtype
__all__ = [
"int8", "byte",
"int16", "short",
"int32", "intc",
"int64", "intp",
"uint8", "ubyte",
"uint16", "ushort",
"uint32", "uintc",
"uint64", "uintp",
"float16", "half",
"float32", "single",
"float64", "double",
"bool_", "float_",
"list_", "tuple_",
"int_", "uint",
"number", "tensor",
"string", "type_none",
"tensor_type",
"Type", "Int",
"complex64", "complex128",
# __method__ from dtype
"dtype_to_nptype", "issubclass_", "dtype_to_pytype",
"pytype_to_dtype", "get_py_obj_dtype"
]
__all__ = dtype.__all__
__all__.extend([
"Tensor", "RowTensor", "SparseTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype",
"set_seed", "get_seed" # random seed
])
"set_seed", "get_seed", # random seed
"set_dump"
])

89
mindspore/common/dump.py Normal file
View File

@ -0,0 +1,89 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Controlling dump behavior."""
from mindspore._c_expression import security
def set_dump(target, enabled=True):
"""
Enable or disable dump for the cell instance and its contents.
The default enabled status for a cell is False. Please note that this
mode takes effect only when the dump_mode field in dump config file is
2. See the `dump document <https://mindspore.cn/docs/programming_guide/zh-CN/master/dump_in_graph_mode.html>`_
for details.
.. warning::
This is an experimental prototype that is subject to change and/or
deletion.
Note:
1. This API is only effective for GRAPH_MODE with Ascend backend.
2. When input is a cell, this API is only effective for the members of
the cell instance. If an operator is not a member of the cell
instance, the dump flag will not be set for this operator (e.g.
functional operators used directly in construct method). To make
this API effective, please use self.some_op = SomeOp() in cell
__init__ method.
Args:
target (Union[Cell, Primitive]): The Cell instance or Primitive instance
to which the dump flag is set.
enabled (bool): True means enable dump, False means disable dump.
Default: True.
Examples:
>>> from mindspore.nn import Cell
>>> class MyNet(Cell):
... def __init__(self):
... super().__init__()
... self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
... self.relu1 = nn.ReLU()
...
... def construct(self, x):
... x = self.conv1(x)
... x = self.relu1(x)
... return x
>>> net = MyNet()
>>> set_dump(net.conv1)
"""
if security.enable_security():
raise ValueError('The set_dump API is not supported, please recompile '
'source without "-s on".')
import mindspore.nn as nn # avoid circular import
from mindspore.ops import Primitive
if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
raise ValueError(f"The \"target\" parameter must be an instance of "
f"Cell or Primitive, "
f"but got an instance of {type(target)}.")
if not isinstance(enabled, bool):
raise ValueError("The \"enabled\" parameter must be bool.")
mode = "true" if enabled else "false"
if isinstance(target, nn.Cell):
primitives = getattr(target, "_primitives", {})
for value in primitives.values():
if value:
value.add_prim_attr("dump", mode)
for cell in target.cells():
set_dump(cell, enabled)
return
if isinstance(target, Primitive):
target.add_prim_attr("dump", mode)
return

View File

@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test code for mindspore.common package."""

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test dump."""
import pytest
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import set_dump
def test_set_dump_on_cell():
"""
Feature: Python API set_dump.
Description: Use set_dump API on Cell instance.
Expectation: Success.
"""
class MyNet(nn.Cell):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
self.relu1 = nn.ReLU()
def construct(self, x):
x = self.conv1(x)
x = self.relu1(x)
return x
net = MyNet()
set_dump(net.conv1)
def test_set_dump_on_primitive():
"""
Feature: Python API set_dump.
Description: Use set_dump API on Primitive instance.
Expectation: Success.
"""
op = ops.Add()
set_dump(op)
def test_input_type_check():
"""
Feature: Python API set_dump.
Description: Use set_dump API on unsupported instance.
Expectation: Throw ValueError exception.
"""
with pytest.raises(ValueError):
set_dump(1)