forked from mindspore-Ecosystem/mindspore
!25488 add the set_dump API
Merge pull request !25488 from wenkai/wk1027setdump1
This commit is contained in:
commit
5af8572cd7
|
@ -12,6 +12,7 @@
|
||||||
"mindspore/mindspore/_check_version.py" "unused-import"
|
"mindspore/mindspore/_check_version.py" "unused-import"
|
||||||
"mindspore/mindspore/_check_version.py" "broad-except"
|
"mindspore/mindspore/_check_version.py" "broad-except"
|
||||||
"mindspore/mindspore/common/parameter.py" "protected-access"
|
"mindspore/mindspore/common/parameter.py" "protected-access"
|
||||||
|
"mindspore/mindspore/common/dtype.py" "undefined-all-variable"
|
||||||
"mindspore/mindspore/context.py" "protected-access"
|
"mindspore/mindspore/context.py" "protected-access"
|
||||||
"mindspore/mindspore/ops/operations" "super-init-not-called"
|
"mindspore/mindspore/ops/operations" "super-init-not-called"
|
||||||
"mindspore/mindspore/ops/operations/_quant_ops.py" "unused-import"
|
"mindspore/mindspore/ops/operations/_quant_ops.py" "unused-import"
|
||||||
|
|
|
@ -15,17 +15,48 @@
|
||||||
"""Top-level reference to dtype of common module."""
|
"""Top-level reference to dtype of common module."""
|
||||||
from . import dtype
|
from . import dtype
|
||||||
from .api import ms_function
|
from .api import ms_function
|
||||||
from .dtype import *
|
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||||
|
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||||
|
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||||
|
uint, number, tensor, string, type_none, tensor_type, Int, \
|
||||||
|
complex64, complex128, dtype_to_nptype, issubclass_, \
|
||||||
|
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
|
||||||
|
from .dump import set_dump
|
||||||
from .parameter import Parameter, ParameterTuple
|
from .parameter import Parameter, ParameterTuple
|
||||||
from .tensor import Tensor, RowTensor, SparseTensor
|
|
||||||
from .seed import set_seed, get_seed
|
from .seed import set_seed, get_seed
|
||||||
|
from .tensor import Tensor, RowTensor, SparseTensor
|
||||||
|
|
||||||
|
# symbols from dtype
|
||||||
|
__all__ = [
|
||||||
|
"int8", "byte",
|
||||||
|
"int16", "short",
|
||||||
|
"int32", "intc",
|
||||||
|
"int64", "intp",
|
||||||
|
"uint8", "ubyte",
|
||||||
|
"uint16", "ushort",
|
||||||
|
"uint32", "uintc",
|
||||||
|
"uint64", "uintp",
|
||||||
|
"float16", "half",
|
||||||
|
"float32", "single",
|
||||||
|
"float64", "double",
|
||||||
|
"bool_", "float_",
|
||||||
|
"list_", "tuple_",
|
||||||
|
"int_", "uint",
|
||||||
|
"number", "tensor",
|
||||||
|
"string", "type_none",
|
||||||
|
"tensor_type",
|
||||||
|
"Type", "Int",
|
||||||
|
"complex64", "complex128",
|
||||||
|
# __method__ from dtype
|
||||||
|
"dtype_to_nptype", "issubclass_", "dtype_to_pytype",
|
||||||
|
"pytype_to_dtype", "get_py_obj_dtype"
|
||||||
|
]
|
||||||
|
|
||||||
__all__ = dtype.__all__
|
|
||||||
__all__.extend([
|
__all__.extend([
|
||||||
"Tensor", "RowTensor", "SparseTensor", # tensor
|
"Tensor", "RowTensor", "SparseTensor", # tensor
|
||||||
'ms_function', # api
|
'ms_function', # api
|
||||||
'Parameter', 'ParameterTuple', # parameter
|
'Parameter', 'ParameterTuple', # parameter
|
||||||
"dtype",
|
"dtype",
|
||||||
"set_seed", "get_seed" # random seed
|
"set_seed", "get_seed", # random seed
|
||||||
])
|
"set_dump"
|
||||||
|
])
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Controlling dump behavior."""
|
||||||
|
|
||||||
|
from mindspore._c_expression import security
|
||||||
|
|
||||||
|
|
||||||
|
def set_dump(target, enabled=True):
|
||||||
|
"""
|
||||||
|
Enable or disable dump for the cell instance and its contents.
|
||||||
|
|
||||||
|
The default enabled status for a cell is False. Please note that this
|
||||||
|
mode takes effect only when the dump_mode field in dump config file is
|
||||||
|
2. See the `dump document <https://mindspore.cn/docs/programming_guide/zh-CN/master/dump_in_graph_mode.html>`_
|
||||||
|
for details.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or
|
||||||
|
deletion.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. This API is only effective for GRAPH_MODE with Ascend backend.
|
||||||
|
2. When input is a cell, this API is only effective for the members of
|
||||||
|
the cell instance. If an operator is not a member of the cell
|
||||||
|
instance, the dump flag will not be set for this operator (e.g.
|
||||||
|
functional operators used directly in construct method). To make
|
||||||
|
this API effective, please use self.some_op = SomeOp() in cell
|
||||||
|
__init__ method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Union[Cell, Primitive]): The Cell instance or Primitive instance
|
||||||
|
to which the dump flag is set.
|
||||||
|
enabled (bool): True means enable dump, False means disable dump.
|
||||||
|
Default: True.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore.nn import Cell
|
||||||
|
>>> class MyNet(Cell):
|
||||||
|
... def __init__(self):
|
||||||
|
... super().__init__()
|
||||||
|
... self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
|
||||||
|
... self.relu1 = nn.ReLU()
|
||||||
|
...
|
||||||
|
... def construct(self, x):
|
||||||
|
... x = self.conv1(x)
|
||||||
|
... x = self.relu1(x)
|
||||||
|
... return x
|
||||||
|
>>> net = MyNet()
|
||||||
|
>>> set_dump(net.conv1)
|
||||||
|
"""
|
||||||
|
if security.enable_security():
|
||||||
|
raise ValueError('The set_dump API is not supported, please recompile '
|
||||||
|
'source without "-s on".')
|
||||||
|
|
||||||
|
import mindspore.nn as nn # avoid circular import
|
||||||
|
from mindspore.ops import Primitive
|
||||||
|
if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
|
||||||
|
raise ValueError(f"The \"target\" parameter must be an instance of "
|
||||||
|
f"Cell or Primitive, "
|
||||||
|
f"but got an instance of {type(target)}.")
|
||||||
|
|
||||||
|
if not isinstance(enabled, bool):
|
||||||
|
raise ValueError("The \"enabled\" parameter must be bool.")
|
||||||
|
|
||||||
|
mode = "true" if enabled else "false"
|
||||||
|
if isinstance(target, nn.Cell):
|
||||||
|
primitives = getattr(target, "_primitives", {})
|
||||||
|
for value in primitives.values():
|
||||||
|
if value:
|
||||||
|
value.add_prim_attr("dump", mode)
|
||||||
|
for cell in target.cells():
|
||||||
|
set_dump(cell, enabled)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(target, Primitive):
|
||||||
|
target.add_prim_attr("dump", mode)
|
||||||
|
return
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Test code for mindspore.common package."""
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Test dump."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as ops
|
||||||
|
from mindspore import set_dump
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_dump_on_cell():
|
||||||
|
"""
|
||||||
|
Feature: Python API set_dump.
|
||||||
|
Description: Use set_dump API on Cell instance.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
class MyNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
|
||||||
|
self.relu1 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
net = MyNet()
|
||||||
|
set_dump(net.conv1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_dump_on_primitive():
|
||||||
|
"""
|
||||||
|
Feature: Python API set_dump.
|
||||||
|
Description: Use set_dump API on Primitive instance.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
op = ops.Add()
|
||||||
|
set_dump(op)
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_type_check():
|
||||||
|
"""
|
||||||
|
Feature: Python API set_dump.
|
||||||
|
Description: Use set_dump API on unsupported instance.
|
||||||
|
Expectation: Throw ValueError exception.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
set_dump(1)
|
Loading…
Reference in New Issue