From cc26b5394cfc5297c21906df0d8db983aea9799c Mon Sep 17 00:00:00 2001
From: dayschan
Date: Fri, 3 Dec 2021 11:43:45 +0800
Subject: [PATCH] remove api "GraphKernel" and prim "InplaceAssign"

---
 mindspore/nn/__init__.py                      |   4 +-
 mindspore/nn/cell.py                          |  38 ------
 mindspore/ops/_op_impl/akg/ascend/__init__.py |   1 -
 .../ops/_op_impl/akg/ascend/inplace_assign.py |  41 ------
 mindspore/ops/_op_impl/akg/gpu/__init__.py    |   1 -
 .../ops/_op_impl/akg/gpu/inplace_assign.py    |  41 ------
 mindspore/ops/op_selector.py                  | 124 ------------------
 mindspore/ops/operations/__init__.py          |   2 +-
 mindspore/ops/operations/other_ops.py         |  48 -------
 9 files changed, 3 insertions(+), 297 deletions(-)
 delete mode 100644 mindspore/ops/_op_impl/akg/ascend/inplace_assign.py
 delete mode 100644 mindspore/ops/_op_impl/akg/gpu/inplace_assign.py
 delete mode 100644 mindspore/ops/op_selector.py

diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py
index 13c54bbd3b2..5ee97f5914e 100644
--- a/mindspore/nn/__init__.py
+++ b/mindspore/nn/__init__.py
@@ -21,7 +21,7 @@ from . import layer, loss, optim, metrics, wrap, grad, probability, sparse, dyna
     reinforcement
 from .learning_rate_schedule import *
 from .dynamic_lr import *
-from .cell import Cell, GraphKernel, GraphCell
+from .cell import Cell, GraphCell
 from .layer import *
 from .loss import *
 from .optim import *
@@ -31,7 +31,7 @@ from .grad import Jvp, Vjp
 from .sparse import *
 from .reinforcement import *
 
-__all__ = ["Cell", "GraphKernel", "GraphCell"]
+__all__ = ["Cell", "GraphCell"]
 __all__.extend(layer.__all__)
 __all__.extend(loss.__all__)
 __all__.extend(optim.__all__)
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 177006b67b2..286435ffcd5 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -24,7 +24,6 @@ import numpy
 from mindspore._checkparam import args_type_check
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
-from mindspore.common._decorator import deprecated
 from mindspore.context import ParallelMode
 from .. import context
 from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType
@@ -1649,43 +1648,6 @@ class Cell(Cell_):
         return params
 
 
-class GraphKernel(Cell):
-    """
-    Base class for GraphKernel.
-
-    A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
-    enable_graph_kernel in context is set to True.
-
-    This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.
-
-    GraphKernel is not supported user-defined cells anymore, the `GraphKernel` objects will be treated as
-    normal `Cell` objects.
-
-    Args:
-        auto_prefix (bool): Recursively generate namespaces. Default: True.
-        flags (dict) : Set graph flags. Default: None.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> class Relu(nn.GraphKernel):
-        ...     def __init__(self):
-        ...         super(Relu, self).__init__()
-        ...         self.max = P.Maximum()
-        ...
-        ...     def construct(self, x):
-        ...         return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
-    """
-
-    @deprecated("1.3", "Cell", True)
-    def __init__(self, auto_prefix=True, flags=None):
-        super(GraphKernel, self).__init__(auto_prefix, flags)
-
-    def construct(self):
-        raise NotImplementedError
-
-
 class GraphCell(Cell):
     """
     Base class for running the graph loaded from a MindIR.
diff --git a/mindspore/ops/_op_impl/akg/ascend/__init__.py b/mindspore/ops/_op_impl/akg/ascend/__init__.py
index 41127a2806a..c8f680b299f 100644
--- a/mindspore/ops/_op_impl/akg/ascend/__init__.py
+++ b/mindspore/ops/_op_impl/akg/ascend/__init__.py
@@ -24,7 +24,6 @@ from .exp import _exp_akg
 from .expand_dims import _expand_dims_akg
 from .greater import _greater_akg
 from .greater_equal import _greater_equal_akg
-from .inplace_assign import _inplace_assign_akg
 from .less import _less_akg
 from .less_equal import _less_equal_akg
 from .log import _log_akg
diff --git a/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py b/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py
deleted file mode 100644
index 9f76706440e..00000000000
--- a/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""InplaceAssign op"""
-from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
-
-op_info = AkgAscendRegOp("InplaceAssign") \
-    .fusion_type("ELEMWISE") \
-    .input(0, "x") \
-    .input(1, "y") \
-    .input(2, "z") \
-    .output(0, "output") \
-    .attr("fake_output", "optional", "bool") \
-    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
-    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
-    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
-    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
-    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
-    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
-    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
-    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
-    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
-    .get_op_info()
-
-
-@op_info_register(op_info)
-def _inplace_assign_akg():
-    """InplaceAssign Akg register"""
-    return
diff --git a/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/ops/_op_impl/akg/gpu/__init__.py
index 05f0cba2e35..f3a18cdc09b 100644
--- a/mindspore/ops/_op_impl/akg/gpu/__init__.py
+++ b/mindspore/ops/_op_impl/akg/gpu/__init__.py
@@ -15,7 +15,6 @@
 """__init__"""
 from .equal import _equal_akg
 from .greater_equal import _greater_equal_akg
-from .inplace_assign import _inplace_assign_akg
 from .lessequal import _lessequal_akg
 from .logical_and import _logical_and_akg
 from .logical_not import _logical_not_akg
diff --git a/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py b/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py
deleted file mode 100644
index b1a934acc2b..00000000000
--- a/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""InplaceAssign op"""
-from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
-
-op_info = AkgGpuRegOp("InplaceAssign") \
-    .fusion_type("ELEMWISE") \
-    .input(0, "x") \
-    .input(1, "y") \
-    .input(2, "z") \
-    .output(0, "output") \
-    .attr("fake_output", "optional", "bool") \
-    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
-    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
-    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
-    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
-    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
-    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
-    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
-    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
-    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
-    .get_op_info()
-
-
-@op_info_register(op_info)
-def _inplace_assign_akg():
-    """InplaceAssign Akg register"""
-    return
diff --git a/mindspore/ops/op_selector.py b/mindspore/ops/op_selector.py
deleted file mode 100644
index a2ebb47983a..00000000000
--- a/mindspore/ops/op_selector.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-A factory class that create op selector instance to config switch on a class,
-which can be used to control the switch of op type: GraphKernel or Primitive.
-"""
-import importlib
-import inspect
-from mindspore import context
-from mindspore.common._decorator import deprecated
-
-
-class _OpSelector:
-    """
-    A helper class, which can be used to choose different type of operator.
-
-    When an instance of this class is called, we return the right operator
-    according to the context['enable_graph_kernel'] and the name of the
-    parameter. returned operator will be a GraphKernel op ora Primitive op.
-
-    Args:
-        op (class): an empty class has an operator name as its class name
-        config_optype (str): operator type, which must be either 'GraphKernel'
-            or 'Primitive'
-        graph_kernel_pkg (str): real operator's package name
-        primitive_pkg (str): graph kernel operator's package name
-
-    Examples:
-        >>> class A: pass
-        >>> selected_op = _OpSelector(A, "GraphKernel",
"graph_kernel.ops.pkg", "primitive.ops.pkg") - >>> # selected_op() will call graph_kernel.ops.pkg.A() - """ - GRAPH_KERNEL = "GraphKernel" - PRIMITIVE = "Primitive" - DEFAULT_OP_TYPE = PRIMITIVE - KW_STR = "op_type" - - def __init__(self, op, config_optype, primitive_pkg, graph_kernel_pkg): - self.op_name = op.__name__ - self.config_optype = config_optype - self.graph_kernel_pkg = graph_kernel_pkg - self.primitive_pkg = primitive_pkg - - def __call__(self, *args, **kwargs): - _op_type = _OpSelector.DEFAULT_OP_TYPE - if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"): - if _OpSelector.KW_STR in kwargs: - _op_type = kwargs.get(_OpSelector.KW_STR) - kwargs.pop(_OpSelector.KW_STR, None) - elif self.config_optype is not None: - _op_type = self.config_optype - if _op_type == _OpSelector.GRAPH_KERNEL: - pkg = self.graph_kernel_pkg - else: - pkg = self.primitive_pkg - op = getattr(importlib.import_module(pkg, __package__), self.op_name) - return op(*args, **kwargs) - - -@deprecated("1.3", "basic Primitive", False) -def new_ops_selector(primitive_pkg, graph_kernel_pkg): - """ - A factory method to return an op selector - - When the GraphKernel switch is on: - `context.get_context('enable_graph_kernel') == True`, we have 2 ways to control the op type: - (1). call the real op with an extra parameter `op_type='Primitive'` or `op_type='GraphKernel'` - (2). pass a parameter to the op selector, like `@op_selector('Primitive')` or - `@op_selector('GraphKernel')` - (3). default op type is PRIMITIVE - The order of the highest priority to lowest priority is (1), (2), (3) - If the GraphKernel switch is off, then op_type will always be PRIMITIVE. - - The user-defined GraphKernel Cell is deprecated, this interface will be removed in a future version. - - Args: - primitive_pkg (str): primitive op's package name - graph_kernel_pkg (str): graph kernel op's package name - - Returns: - returns an op selector, which can control what operator should be actually called. - - Examples: - >>> op_selector = new_ops_selector("primitive_pkg.some.path", - ... "graph_kernel_pkg.some.path") - >>> @op_selector - >>> class ReduceSum: pass - """ - - def op_selector(cls_or_optype): - - _primitive_pkg = primitive_pkg - _graph_kernel_pkg = graph_kernel_pkg - - def direct_op_type(): - darg = None - if cls_or_optype is None: - pass - elif not inspect.isclass(cls_or_optype): - darg = cls_or_optype - return darg - - if direct_op_type() is not None: - def deco_cls(_real_cls): - return _OpSelector(_real_cls, direct_op_type(), _primitive_pkg, _graph_kernel_pkg) - return deco_cls - - return _OpSelector(cls_or_optype, direct_op_type(), _primitive_pkg, _graph_kernel_pkg) - - return op_selector diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 088201a0fb1..013008d3001 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -89,7 +89,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink) from . 
 from . import _quant_ops
 from ._quant_ops import *
-from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
+from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                         ConfusionMatrix, PopulationCount, UpdateState, Load, CheckValid, Partial, Depend,
                         identity, CheckBprop, Push, Pull, PullWeight, PushWeight, PushMetrics, StartFLJob,
                         UpdateModel, GetModel, PyFunc, ExchangeKeys, GetKeys)
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 8b9f97e65f8..666d2e2c342 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -17,7 +17,6 @@
 import functools
 import mindspore.common._monad as monad
 from mindspore import log as logger
-from mindspore.common._decorator import deprecated
 from .. import signature as sig
 from ..._checkparam import Validator as validator, Rel
 from ...common import dtype as mstype
@@ -71,53 +70,6 @@ class Assign(Primitive):
         self.add_prim_attr('side_effect_mem', True)
 
 
-class InplaceAssign(PrimitiveWithInfer):
-    """
-    Inplace assign `Parameter` with a value.
-    This primitive can only use in graph kernel.
-
-    InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.
-
-    Inputs:
-        - **variable** (Parameter) - The `Parameter`.
-        - **value** (Tensor) - The value to be assigned.
-        - **depend** (Tensor) - The dependent tensor to keep this op connected in graph.
-
-    Outputs:
-        Tensor, has the same type as original `variable`.
-
-    Raises:
-        TypeError: If `value` or `depend` is not a Tensor.
-
-    Examples:
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.inplace_assign = ops.InplaceAssign()
-        ...
-        ...     def construct(self, x):
-        ...         val = x - 1.0
-        ...         ret = x + 2.0
-        ...         return self.inplace_assign(x, val, ret)
-        ...
-        >>> x = Tensor([2.0], mindspore.float32)
-        >>> net = Net()
-        >>> output = net(x)
-        >>> print(output)
-    """
-    @deprecated("1.3", "Assign", False)
-    @prim_attr_register
-    def __init__(self):
-        """Initialize InplaceAssign."""
-        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
-
-    def infer_shape(self, x, y, z):
-        return z
-
-    def infer_dtype(self, x, y, z):
-        return z
-
-
 class Load(PrimitiveWithCheck):
     """
     Load `Parameter` to a value.
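
Migration sketch (appended after the diff; not part of the applied patch). The
removed InplaceAssign(variable, value, depend) wrote `value` into `variable`
and returned `depend` to keep the write connected in the graph. As the deleted
docstrings suggest ("use Assign instead", "use Cell instead"), the same effect
is available from the stock Assign and Depend primitives. The Net class and
the parameter name below are hypothetical, chosen only for illustration:

    import mindspore
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import Parameter, Tensor

    class Net(nn.Cell):
        """Rewrites the old InplaceAssign docstring example with Assign."""
        def __init__(self):
            super(Net, self).__init__()
            self.assign = ops.Assign()
            self.depend = ops.Depend()
            # Plays the role of the old `variable` input.
            self.param = Parameter(Tensor([2.0], mindspore.float32), name="param")

        def construct(self):
            val = self.param - 1.0   # old `value` input
            ret = self.param + 2.0   # old `depend` input (the returned value)
            # Depend ties the assignment into the graph so it is not pruned,
            # mirroring the third input of the removed InplaceAssign.
            return self.depend(ret, self.assign(self.param, val))

A cell that subclassed the removed nn.GraphKernel only needs its base class
changed to nn.Cell; kernel fusion is still driven by the enable_graph_kernel
context flag, not by the cell type.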