forked from mindspore-Ecosystem/mindspore
add the set_dump API, see I4AUIR for details
This commit is contained in:
parent
c063fe67b9
commit
d2eca80d40
|
@ -12,6 +12,7 @@
|
|||
"mindspore/mindspore/_check_version.py" "unused-import"
|
||||
"mindspore/mindspore/_check_version.py" "broad-except"
|
||||
"mindspore/mindspore/common/parameter.py" "protected-access"
|
||||
"mindspore/mindspore/common/dtype.py" "undefined-all-variable"
|
||||
"mindspore/mindspore/context.py" "protected-access"
|
||||
"mindspore/mindspore/ops/operations" "super-init-not-called"
|
||||
"mindspore/mindspore/ops/operations/_quant_ops.py" "unused-import"
|
||||
|
|
|
@ -15,17 +15,48 @@
|
|||
"""Top-level reference to dtype of common module."""
|
||||
from . import dtype
|
||||
from .api import ms_function
|
||||
from .dtype import *
|
||||
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||
uint, number, tensor, string, type_none, tensor_type, Int, \
|
||||
complex64, complex128, dtype_to_nptype, issubclass_, \
|
||||
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
|
||||
from .dump import set_dump
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import Tensor, RowTensor, SparseTensor
|
||||
from .seed import set_seed, get_seed
|
||||
from .tensor import Tensor, RowTensor, SparseTensor
|
||||
|
||||
# symbols from dtype
|
||||
__all__ = [
|
||||
"int8", "byte",
|
||||
"int16", "short",
|
||||
"int32", "intc",
|
||||
"int64", "intp",
|
||||
"uint8", "ubyte",
|
||||
"uint16", "ushort",
|
||||
"uint32", "uintc",
|
||||
"uint64", "uintp",
|
||||
"float16", "half",
|
||||
"float32", "single",
|
||||
"float64", "double",
|
||||
"bool_", "float_",
|
||||
"list_", "tuple_",
|
||||
"int_", "uint",
|
||||
"number", "tensor",
|
||||
"string", "type_none",
|
||||
"tensor_type",
|
||||
"Type", "Int",
|
||||
"complex64", "complex128",
|
||||
# __method__ from dtype
|
||||
"dtype_to_nptype", "issubclass_", "dtype_to_pytype",
|
||||
"pytype_to_dtype", "get_py_obj_dtype"
|
||||
]
|
||||
|
||||
__all__ = dtype.__all__
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype",
|
||||
"set_seed", "get_seed" # random seed
|
||||
])
|
||||
"set_seed", "get_seed", # random seed
|
||||
"set_dump"
|
||||
])
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Controlling dump behavior."""
|
||||
|
||||
from mindspore._c_expression import security
|
||||
|
||||
|
||||
def set_dump(target, enabled=True):
|
||||
"""
|
||||
Enable or disable dump for the cell instance and its contents.
|
||||
|
||||
The default enabled status for a cell is False. Please note that this
|
||||
mode takes effect only when the dump_mode field in dump config file is
|
||||
2. See the `dump document <https://mindspore.cn/docs/programming_guide/zh-CN/master/dump_in_graph_mode.html>`_
|
||||
for details.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or
|
||||
deletion.
|
||||
|
||||
Note:
|
||||
1. This API is only effective for GRAPH_MODE with Ascend backend.
|
||||
2. When input is a cell, this API is only effective for the members of
|
||||
the cell instance. If an operator is not a member of the cell
|
||||
instance, the dump flag will not be set for this operator (e.g.
|
||||
functional operators used directly in construct method). To make
|
||||
this API effective, please use self.some_op = SomeOp() in cell
|
||||
__init__ method.
|
||||
|
||||
Args:
|
||||
target (Union[Cell, Primitive]): The Cell instance or Primitive instance
|
||||
to which the dump flag is set.
|
||||
enabled (bool): True means enable dump, False means disable dump.
|
||||
Default: True.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.nn import Cell
|
||||
>>> class MyNet(Cell):
|
||||
... def __init__(self):
|
||||
... super().__init__()
|
||||
... self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
|
||||
... self.relu1 = nn.ReLU()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... x = self.conv1(x)
|
||||
... x = self.relu1(x)
|
||||
... return x
|
||||
>>> net = MyNet()
|
||||
>>> set_dump(net.conv1)
|
||||
"""
|
||||
if security.enable_security():
|
||||
raise ValueError('The set_dump API is not supported, please recompile '
|
||||
'source without "-s on".')
|
||||
|
||||
import mindspore.nn as nn # avoid circular import
|
||||
from mindspore.ops import Primitive
|
||||
if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
|
||||
raise ValueError(f"The \"target\" parameter must be an instance of "
|
||||
f"Cell or Primitive, "
|
||||
f"but got an instance of {type(target)}.")
|
||||
|
||||
if not isinstance(enabled, bool):
|
||||
raise ValueError("The \"enabled\" parameter must be bool.")
|
||||
|
||||
mode = "true" if enabled else "false"
|
||||
if isinstance(target, nn.Cell):
|
||||
primitives = getattr(target, "_primitives", {})
|
||||
for value in primitives.values():
|
||||
if value:
|
||||
value.add_prim_attr("dump", mode)
|
||||
for cell in target.cells():
|
||||
set_dump(cell, enabled)
|
||||
return
|
||||
|
||||
if isinstance(target, Primitive):
|
||||
target.add_prim_attr("dump", mode)
|
||||
return
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test code for mindspore.common package."""
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test dump."""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import set_dump
|
||||
|
||||
|
||||
def test_set_dump_on_cell():
|
||||
"""
|
||||
Feature: Python API set_dump.
|
||||
Description: Use set_dump API on Cell instance.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class MyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MyNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.relu1(x)
|
||||
return x
|
||||
|
||||
net = MyNet()
|
||||
set_dump(net.conv1)
|
||||
|
||||
|
||||
def test_set_dump_on_primitive():
|
||||
"""
|
||||
Feature: Python API set_dump.
|
||||
Description: Use set_dump API on Primitive instance.
|
||||
Expectation: Success.
|
||||
"""
|
||||
op = ops.Add()
|
||||
set_dump(op)
|
||||
|
||||
|
||||
def test_input_type_check():
|
||||
"""
|
||||
Feature: Python API set_dump.
|
||||
Description: Use set_dump API on unsupported instance.
|
||||
Expectation: Throw ValueError exception.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
set_dump(1)
|
Loading…
Reference in New Issue