diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index ba92cd49817..0018b5b26c9 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -88,7 +88,7 @@ class WithLossCell(Cell): TypeError: If dtype of `data` or `label` is neither float16 nor float32. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> net = Net() @@ -150,7 +150,7 @@ class WithGradCell(Cell): TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> # For a defined network Net without loss function @@ -304,7 +304,7 @@ class TrainOneStepCell(Cell): TypeError: If `sens` is not a number. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> net = Net() @@ -605,7 +605,7 @@ class WithEvalCell(Cell): TypeError: If `add_cast_fp32` is not a bool. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> # For a defined network Net without loss function @@ -642,7 +642,7 @@ class ParameterUpdate(Cell): KeyError: If parameter with the specified name does not exist. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> network = nn.Dense(3, 4) diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index f494d3fdf15..cc4ce5c3271 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -77,7 +77,7 @@ def core(fn=None, **flags): other flag. Default: None. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> net = Net() @@ -271,9 +271,9 @@ class GradOperation(GradOperation_): >>> print(output) (Tensor(shape=[2, 3], dtype=Float32, value= [[ 2.21099997e+00, 5.09999990e-01, 1.49000001e+00], - [ 5.58799982e+00, 2.68000007e+00, 4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value= + [ 5.58800030e+00, 2.68000007e+00, 4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value= [[ 1.51999998e+00, 2.81999993e+00, 2.14000010e+00], - [ 1.09999990e+00, 2.04999971e+00, 1.54999995e+00], + [ 1.09999990e+00, 2.04999995e+00, 1.54999995e+00], [ 9.00000036e-01, 1.54999995e+00, 1.25000000e+00]])) >>> >>> class GradNetWithWrtParams(nn.Cell):