!48732 [lite]improve performance of CropAndResize-vmap

Merge pull request !48732 from 徐安越/master3
This commit is contained in:
i-robot 2023-02-21 06:38:37 +00:00 committed by Gitee
commit 407ea3f64e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 12 additions and 5 deletions

View File

@ -16,10 +16,12 @@
"""image_ops vmap impl."""
from __future__ import absolute_import
import mindspore.numpy as mnp
import numpy as np
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import image_ops as IMG
from mindspore.ops import constexpr
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
_raise_value_error
@ -90,6 +92,13 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
def get_crop_and_resize_vmap_rule(prim, axis_size):
"""VmapRule for `CropAndResize` operation."""
@constexpr
def get_box_indices_offsets(axis_size, batch_size, num_boxes):
offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
offsets = np.reshape(offsets, (axis_size, 1))
offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
return Tensor(offsets)
def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
if is_all_none:
@ -115,10 +124,8 @@ def get_crop_and_resize_vmap_rule(prim, axis_size):
x = _bdim_at_front(x, x_dim, axis_size)
x_shape = F.shape(x)
x = F.reshape(x, (-1,) + x_shape[2:])
counts = mnp.arange(0, axis_size * x_shape[1], x_shape[1])
counts = F.reshape(counts, (axis_size, 1))
counts = F.broadcast_to(counts, (axis_size, num_boxes))
box_indices = F.add(box_indices, counts)
offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
box_indices = F.add(box_indices, offsets)
box_indices = F.reshape(box_indices, (-1,))
out = prim(x, boxes, box_indices, crop_size)