forked from mindspore-Ecosystem/mindspore
!48732 [lite]improve performance of CropAndResize-vmap
Merge pull request !48732 from 徐安越/master3
This commit is contained in:
commit
407ea3f64e
|
@ -16,10 +16,12 @@
|
|||
"""image_ops vmap impl."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import mindspore.numpy as mnp
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import image_ops as IMG
|
||||
from mindspore.ops import constexpr
|
||||
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
||||
_raise_value_error
|
||||
|
||||
|
@ -90,6 +92,13 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
|
|||
def get_crop_and_resize_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `CropAndResize` operation."""
|
||||
|
||||
@constexpr
|
||||
def get_box_indices_offsets(axis_size, batch_size, num_boxes):
|
||||
offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
|
||||
offsets = np.reshape(offsets, (axis_size, 1))
|
||||
offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
|
||||
return Tensor(offsets)
|
||||
|
||||
def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
|
||||
if is_all_none:
|
||||
|
@ -115,10 +124,8 @@ def get_crop_and_resize_vmap_rule(prim, axis_size):
|
|||
x = _bdim_at_front(x, x_dim, axis_size)
|
||||
x_shape = F.shape(x)
|
||||
x = F.reshape(x, (-1,) + x_shape[2:])
|
||||
counts = mnp.arange(0, axis_size * x_shape[1], x_shape[1])
|
||||
counts = F.reshape(counts, (axis_size, 1))
|
||||
counts = F.broadcast_to(counts, (axis_size, num_boxes))
|
||||
box_indices = F.add(box_indices, counts)
|
||||
offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
|
||||
box_indices = F.add(box_indices, offsets)
|
||||
box_indices = F.reshape(box_indices, (-1,))
|
||||
out = prim(x, boxes, box_indices, crop_size)
|
||||
|
||||
|
|
Loading…
Reference in New Issue