!6590 fix bugs of op Debug, ReLUV2, EditDistance and Dense
Merge pull request !6590 from lihongkang/v2_master
This commit is contained in:
commit
d56683157d
|
@ -165,7 +165,7 @@ class Dense(Cell):
|
|||
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
|
||||
|
||||
where :math:`\text{activation}` is the activation function passed as the activation
|
||||
argument (if passed in), :math:`\text{activation}` is a weight matrix with the same
|
||||
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
|
||||
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
|
||||
with the same data type as the inputs created by the layer (only if has_bias is True).
|
||||
|
||||
|
|
|
@ -66,12 +66,3 @@ def get_bprop_insert_gradient_of(self):
|
|||
def bprop(x, out, dout):
|
||||
return (f(dout),)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Debug)
|
||||
def get_bprop_debug(self):
|
||||
"""Generate bprop for Debug"""
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return dout
|
||||
return bprop
|
||||
|
|
|
@ -39,7 +39,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast
|
|||
_VirtualDiv, _GetTensorSlice,
|
||||
_HostAllGather, _HostReduceScatter)
|
||||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Debug, Print, Assert)
|
||||
TensorSummary, HistogramSummary, Print, Assert)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
from .inner_ops import ScalarCast
|
||||
|
||||
|
@ -200,7 +200,6 @@ __all__ = [
|
|||
'ImageSummary',
|
||||
'TensorSummary',
|
||||
'HistogramSummary',
|
||||
"Debug",
|
||||
"Print",
|
||||
"Assert",
|
||||
'InsertGradientOf',
|
||||
|
@ -375,6 +374,7 @@ __all__ = [
|
|||
"ParallelConcat",
|
||||
"Push",
|
||||
"Pull",
|
||||
"ReLUV2",
|
||||
'SparseToDense',
|
||||
]
|
||||
|
||||
|
|
|
@ -3619,6 +3619,12 @@ class EditDistance(PrimitiveWithInfer):
|
|||
Tensor, a dense tensor with rank `R-1` and float32 data type.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>> context.set_context(mode=context.GRAPH_MODE)
|
||||
>>> class EditDistance(nn.Cell):
|
||||
>>> def __init__(self, hypothesis_shape, truth_shape, normalize=True):
|
||||
>>> super(EditDistance, self).__init__()
|
||||
|
@ -3645,6 +3651,7 @@ class EditDistance(PrimitiveWithInfer):
|
|||
def __init__(self, normalize=True):
|
||||
"""Initialize EditDistance"""
|
||||
self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
|
||||
self.set_const_input_indexes([2, 5])
|
||||
|
||||
def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
|
||||
validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)
|
||||
|
|
|
@ -18,7 +18,7 @@ from types import FunctionType, MethodType
|
|||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
||||
|
||||
|
||||
def _check_summary_param(name, value, class_name):
|
||||
|
@ -342,32 +342,6 @@ class Print(PrimitiveWithInfer):
|
|||
return mstype.int32
|
||||
|
||||
|
||||
class Debug(Primitive):
|
||||
"""
|
||||
Prints tensor value.
|
||||
|
||||
Inputs:
|
||||
- **value** (Tensor) - The value of tensor.
|
||||
|
||||
Examples:
|
||||
>>> class DebugNN(nn.Cell):
|
||||
>>> def __init__(self,):
|
||||
>>> self.debug = nn.Debug()
|
||||
>>>
|
||||
>>> def construct(self, x, y):
|
||||
>>> x = self.add(x, y)
|
||||
>>> self.debug(x)
|
||||
>>> return x
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class Assert(PrimitiveWithInfer):
|
||||
"""
|
||||
Asserts that the given condition is true.
|
||||
|
|
Loading…
Reference in New Issue