fixed Inv

This commit is contained in:
jiangjinsheng 2020-06-20 14:46:01 +08:00
parent dffbe6edfc
commit e9d4b9864f
3 changed files with 7 additions and 7 deletions

View File

@ -1044,6 +1044,6 @@ def get_bprop_inv(self):
inv_grad = G.InvGrad()
def bprop(x, out, dout):
dx = inv_grad(x, dout)
dx = inv_grad(out, dout)
return (dx,)
return bprop

View File

@ -2644,7 +2644,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
The length of block_shape is M correspoding to the number of spatial dimensions.
crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
input dimension i+2. It is required that input_shape[i+2]*block_size[i] >= crops[i][0]+crops[i][1].
input dimension i+2. It is required that input_shape[i+2]*block_size[i] > crops[i][0]+crops[i][1].
Inputs:
- **input_x** (Tensor) - The input tensor.

View File

@ -228,20 +228,20 @@ class IOU(PrimitiveWithInfer):
Inputs:
- **anchor_boxes** (Tensor) - Anchor boxes, tensor of shape (N, 4). "N" indicates the number of anchor boxes,
and the value "4" refers to "x0", "x1", "y0", and "y1".
and the value "4" refers to "x0", "x1", "y0", and "y1". Data type must be float16.
- **gt_boxes** (Tensor) - Ground truth boxes, tensor of shape (M, 4). "M" indicates the number of ground
truth boxes, and the value "4" refers to "x0", "x1", "y0", and "y1".
truth boxes, and the value "4" refers to "x0", "x1", "y0", and "y1". Data type must be float16.
Outputs:
Tensor, the 'iou' values, tensor of shape (M, N).
Tensor, the 'iou' values, tensor of shape (M, N), with data type float16.
Raises:
KeyError: When `mode` is not 'iou' or 'iof'.
Examples:
>>> iou = P.IOU()
>>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32)
>>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32)
>>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> iou(anchor_boxes, gt_boxes)
"""