forked from mindspore-Ecosystem/mindspore
reorganize strided_slice_infer code
This commit is contained in:
parent
9f92c1fbb1
commit
c060fcbafb
|
@ -21,6 +21,26 @@ const size_t kStridedSliceInputNum = 1;
|
|||
const size_t kStridedSliceMultiInputNumMin = 3;
|
||||
const size_t kStridedSliceMultiInputNumMax = 5;
|
||||
|
||||
typedef struct StridedSliceTransferBuffer {
|
||||
int ndim_;
|
||||
|
||||
int begins_[MAX_SHAPE_SIZE];
|
||||
int ends_[MAX_SHAPE_SIZE];
|
||||
int strides_[MAX_SHAPE_SIZE];
|
||||
int begins_mask_[MAX_SHAPE_SIZE];
|
||||
int ends_mask_[MAX_SHAPE_SIZE];
|
||||
int ellipsis_mask_[MAX_SHAPE_SIZE];
|
||||
int new_axis_mask_[MAX_SHAPE_SIZE];
|
||||
int shrink_axis_mask_[MAX_SHAPE_SIZE];
|
||||
|
||||
size_t begins_size_;
|
||||
size_t ends_size_;
|
||||
size_t strides_size_;
|
||||
size_t ellipsis_mask_size_;
|
||||
size_t new_axis_mask_size_;
|
||||
size_t shrink_axis_mask_size_;
|
||||
} StridedSliceTransferBuffer;
|
||||
|
||||
bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) {
|
||||
for (size_t i = 1; i < inputs_size; ++i) {
|
||||
if (inputs[i]->data_ == NULL) {
|
||||
|
@ -128,10 +148,8 @@ int HandleAxesInputExist(const TensorC *const *inputs, int *ndim_, int *in_shape
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// note: begin, end, stride length are equal, but may less than rank of input
|
||||
int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
#ifdef Debug
|
||||
int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
if (outputs_size != kStridedSliceOutputNum) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
|
@ -142,6 +160,138 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
|
|||
if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param) {
|
||||
for (int i = 0; i < transfer_buffer->ndim_; i++) {
|
||||
transfer_buffer->begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i);
|
||||
transfer_buffer->ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i);
|
||||
transfer_buffer->ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i);
|
||||
transfer_buffer->new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i);
|
||||
transfer_buffer->shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i);
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_,
|
||||
size_t *in_shape_size) {
|
||||
for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) {
|
||||
if (transfer_buffer->new_axis_mask_[i]) {
|
||||
transfer_buffer->ndim_ += 1;
|
||||
ShapeInsert(in_shape_, in_shape_size, i, 1);
|
||||
transfer_buffer->begins_[i] = 0;
|
||||
transfer_buffer->ends_[i] = 1;
|
||||
transfer_buffer->strides_[i] = 1;
|
||||
|
||||
ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0);
|
||||
ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape_[transfer_buffer->ndim_ - 1]);
|
||||
ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1);
|
||||
|
||||
transfer_buffer->begins_mask_[i] = false;
|
||||
transfer_buffer->ends_mask_[i] = false;
|
||||
transfer_buffer->ellipsis_mask_[i] = false;
|
||||
transfer_buffer->shrink_axis_mask_[i] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) {
|
||||
for (int i = 0; i < transfer_buffer->ndim_; i++) {
|
||||
if (transfer_buffer->begins_mask_[i]) {
|
||||
transfer_buffer->begins_[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
|
||||
for (int i = 0; i < transfer_buffer->ndim_; i++) {
|
||||
if (transfer_buffer->ends_mask_[i]) {
|
||||
transfer_buffer->ends_[i] = in_shape_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
|
||||
for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) {
|
||||
if (transfer_buffer->ellipsis_mask_[i]) {
|
||||
transfer_buffer->begins_[i] = 0;
|
||||
transfer_buffer->ends_[i] = in_shape_[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
|
||||
for (int i = 0; i < (int)(transfer_buffer->begins_size_); ++i) {
|
||||
if (transfer_buffer->begins_[i] < 0) {
|
||||
transfer_buffer->begins_[i] += in_shape_[i];
|
||||
}
|
||||
if (transfer_buffer->ends_[i] < 0) {
|
||||
transfer_buffer->ends_[i] += in_shape_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) {
|
||||
int old_out_shape[MAX_SHAPE_SIZE];
|
||||
size_t old_out_shape_size = 0;
|
||||
ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size);
|
||||
*output_shape_size = 0;
|
||||
for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) {
|
||||
if (transfer_buffer->shrink_axis_mask_[i]) {
|
||||
transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1;
|
||||
transfer_buffer->strides_[i] = 1;
|
||||
} else {
|
||||
ShapePush(output_shape, output_shape_size, old_out_shape[i]);
|
||||
}
|
||||
}
|
||||
for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) {
|
||||
ShapePush(output_shape, output_shape_size, old_out_shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TransferBuffer2Param(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_) {
|
||||
for (int i = 0; i < transfer_buffer->ndim_; i++) {
|
||||
param->begins_[i] = transfer_buffer->begins_[i];
|
||||
param->ends_[i] = transfer_buffer->ends_[i];
|
||||
param->in_shape_[i] = in_shape_[i];
|
||||
param->strides_[i] = transfer_buffer->strides_[i];
|
||||
}
|
||||
|
||||
for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) {
|
||||
param->begins_[i] = 0;
|
||||
param->ends_[i] = in_shape_[i];
|
||||
param->in_shape_[i] = in_shape_[i];
|
||||
param->strides_[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) {
|
||||
transfer_buffer->begins_size_ = 0;
|
||||
transfer_buffer->ends_size_ = 0;
|
||||
transfer_buffer->strides_size_ = 0;
|
||||
transfer_buffer->ellipsis_mask_size_ = 0;
|
||||
transfer_buffer->new_axis_mask_size_ = 0;
|
||||
transfer_buffer->shrink_axis_mask_size_ = 0;
|
||||
}
|
||||
|
||||
void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) {
|
||||
transfer_buffer->ellipsis_mask_size_ = transfer_buffer->ndim_;
|
||||
transfer_buffer->new_axis_mask_size_ = transfer_buffer->ndim_;
|
||||
transfer_buffer->shrink_axis_mask_size_ = transfer_buffer->ndim_;
|
||||
transfer_buffer->begins_size_ = transfer_buffer->ndim_;
|
||||
transfer_buffer->ends_size_ = transfer_buffer->ndim_;
|
||||
transfer_buffer->strides_size_ = transfer_buffer->ndim_;
|
||||
}
|
||||
|
||||
// note: begin, end, stride length are equal, but may less than rank of input
|
||||
int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
#ifdef Debug
|
||||
int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
const TensorC *input = inputs[0];
|
||||
|
@ -152,42 +302,29 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
|
|||
}
|
||||
|
||||
int in_shape_[MAX_SHAPE_SIZE];
|
||||
int begins_[MAX_SHAPE_SIZE];
|
||||
int ends_[MAX_SHAPE_SIZE];
|
||||
size_t in_shape_size_ = 0;
|
||||
if (parameter->infer_flag_) {
|
||||
ShapeSet(in_shape_, &in_shape_size_, input->shape_, input->shape_size_);
|
||||
}
|
||||
size_t begins_size_ = 0;
|
||||
size_t ends_size_ = 0;
|
||||
int strides_[MAX_SHAPE_SIZE];
|
||||
size_t strides_size_ = 0;
|
||||
int begins_mask_[MAX_SHAPE_SIZE];
|
||||
int ends_mask_[MAX_SHAPE_SIZE];
|
||||
int ellipsis_mask_[MAX_SHAPE_SIZE];
|
||||
size_t ellipsis_mask_size_ = 0;
|
||||
int new_axis_mask_[MAX_SHAPE_SIZE];
|
||||
size_t new_axis_mask_size_ = 0;
|
||||
int shrink_axis_mask_[MAX_SHAPE_SIZE];
|
||||
size_t shrink_axis_mask_size_ = 0;
|
||||
size_t in_shape_size = 0;
|
||||
ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_);
|
||||
|
||||
StridedSliceTransferBuffer transfer_buffer;
|
||||
InitStridedSliceTransferBuffer(&transfer_buffer);
|
||||
|
||||
StridedSliceParameter *param = (StridedSliceParameter *)parameter;
|
||||
param->num_axes_ = in_shape_size_;
|
||||
param->in_shape_length_ = in_shape_size_;
|
||||
param->num_axes_ = in_shape_size;
|
||||
param->in_shape_length_ = in_shape_size;
|
||||
|
||||
int ndim_ = 0;
|
||||
transfer_buffer.ndim_ = 0;
|
||||
if (inputs_size == kStridedSliceInputNum) {
|
||||
ndim_ = (int)(param->num_axes_);
|
||||
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
ShapePush(begins_, &begins_size_, param->begins_[i]);
|
||||
ShapePush(ends_, &ends_size_, param->ends_[i]);
|
||||
ShapePush(strides_, &strides_size_, param->strides_[i]);
|
||||
transfer_buffer.ndim_ = (int)(param->num_axes_);
|
||||
for (int i = 0; i < transfer_buffer.ndim_; i++) {
|
||||
ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]);
|
||||
ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]);
|
||||
ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]);
|
||||
}
|
||||
}
|
||||
if (!CheckInputs(inputs, inputs_size)) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
||||
if (inputs_size == 4) {
|
||||
const TensorC *begin_tensor = inputs[1];
|
||||
int *begin_data = (int *)(begin_tensor->data_);
|
||||
|
@ -198,134 +335,45 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
|
|||
if (begin_data == NULL || end_data == NULL || stride_data == NULL) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
ndim_ = GetElementNum(begin_tensor);
|
||||
for (int i = 0; i < ndim_; ++i) {
|
||||
ShapePush(begins_, &begins_size_, begin_data[i]);
|
||||
ShapePush(ends_, &ends_size_, end_data[i]);
|
||||
ShapePush(strides_, &strides_size_, stride_data[i]);
|
||||
transfer_buffer.ndim_ = GetElementNum(begin_tensor);
|
||||
for (int i = 0; i < transfer_buffer.ndim_; ++i) {
|
||||
ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, begin_data[i]);
|
||||
ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, end_data[i]);
|
||||
ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, stride_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs_size == 5) {
|
||||
int ret = HandleAxesInputExist(inputs, &ndim_, in_shape_, begins_, strides_, ends_);
|
||||
int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape_, transfer_buffer.begins_,
|
||||
transfer_buffer.strides_, transfer_buffer.ends_);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
// set all mask to original input shape
|
||||
ellipsis_mask_size_ = ndim_;
|
||||
new_axis_mask_size_ = ndim_;
|
||||
shrink_axis_mask_size_ = ndim_;
|
||||
begins_size_ = ndim_;
|
||||
ends_size_ = ndim_;
|
||||
strides_size_ = ndim_;
|
||||
|
||||
// convert bit to vector
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i);
|
||||
ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i);
|
||||
ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i);
|
||||
new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i);
|
||||
shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i);
|
||||
}
|
||||
|
||||
// ApplyNewAxisMask();
|
||||
for (size_t i = 0; i < new_axis_mask_size_; i++) {
|
||||
if (new_axis_mask_[i]) {
|
||||
ndim_ += 1;
|
||||
ShapeInsert(in_shape_, &in_shape_size_, i, 1);
|
||||
begins_[i] = 0;
|
||||
ends_[i] = 1;
|
||||
strides_[i] = 1;
|
||||
|
||||
ShapePush(begins_, &begins_size_, 0);
|
||||
ShapePush(ends_, &ends_size_, in_shape_[ndim_ - 1]);
|
||||
ShapePush(strides_, &strides_size_, 1);
|
||||
|
||||
begins_mask_[i] = false;
|
||||
ends_mask_[i] = false;
|
||||
ellipsis_mask_[i] = false;
|
||||
shrink_axis_mask_[i] = false;
|
||||
}
|
||||
}
|
||||
// ApplyBeginMask();
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
if (begins_mask_[i]) {
|
||||
begins_[i] = 0;
|
||||
}
|
||||
}
|
||||
// ApplyEndMask();
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
if (ends_mask_[i]) {
|
||||
ends_[i] = in_shape_[i];
|
||||
}
|
||||
}
|
||||
// ApplyEllipsisMask();
|
||||
for (size_t i = 0; i < ellipsis_mask_size_; i++) {
|
||||
if (ellipsis_mask_[i]) {
|
||||
begins_[i] = 0;
|
||||
ends_[i] = in_shape_[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!parameter->infer_flag_) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
SetMaskSize(&transfer_buffer);
|
||||
Bit2Vector(&transfer_buffer, param);
|
||||
ApplyNewAxisMask(&transfer_buffer, param, in_shape_, &in_shape_size);
|
||||
ApplyBeginMask(&transfer_buffer);
|
||||
ApplyEndMask(&transfer_buffer, in_shape_);
|
||||
ApplyEllipsisMask(&transfer_buffer, in_shape_);
|
||||
|
||||
int output_shape[MAX_SHAPE_SIZE];
|
||||
size_t output_shape_size = 0;
|
||||
ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size_);
|
||||
|
||||
// TransIndexToPositive();
|
||||
for (int i = 0; i < (int)(begins_size_); ++i) {
|
||||
if (begins_[i] < 0) {
|
||||
begins_[i] += in_shape_[i];
|
||||
}
|
||||
if (ends_[i] < 0) {
|
||||
ends_[i] += in_shape_[i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
if (strides_[i] == 0) {
|
||||
ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size);
|
||||
TransIndexToPositive(&transfer_buffer, in_shape_);
|
||||
for (int i = 0; i < transfer_buffer.ndim_; i++) {
|
||||
if (transfer_buffer.strides_[i] == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
output_shape[i] = (ends_[i] - begins_[i] + strides_[i] + (strides_[i] < 0 ? 1 : -1)) / strides_[i];
|
||||
output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] +
|
||||
(transfer_buffer.strides_[i] < 0 ? 1 : -1)) /
|
||||
transfer_buffer.strides_[i];
|
||||
}
|
||||
|
||||
// ApplyShrinkMask
|
||||
int old_out_shape[MAX_SHAPE_SIZE];
|
||||
size_t old_out_shape_size = 0;
|
||||
ShapeSet(old_out_shape, &old_out_shape_size, output_shape, output_shape_size);
|
||||
output_shape_size = 0;
|
||||
for (size_t i = 0; i < shrink_axis_mask_size_; i++) {
|
||||
if (shrink_axis_mask_[i]) {
|
||||
ends_[i] = begins_[i] + 1;
|
||||
strides_[i] = 1;
|
||||
} else {
|
||||
ShapePush(output_shape, &output_shape_size, old_out_shape[i]);
|
||||
}
|
||||
}
|
||||
for (size_t i = shrink_axis_mask_size_; i < old_out_shape_size; i++) {
|
||||
ShapePush(output_shape, &output_shape_size, old_out_shape[i]);
|
||||
}
|
||||
|
||||
ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size);
|
||||
SetShapeArray(outputs[0], output_shape, output_shape_size);
|
||||
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
param->begins_[i] = begins_[i];
|
||||
param->ends_[i] = ends_[i];
|
||||
param->in_shape_[i] = in_shape_[i];
|
||||
param->strides_[i] = strides_[i];
|
||||
}
|
||||
|
||||
for (int i = ndim_; i < param->in_shape_length_; i++) {
|
||||
param->begins_[i] = 0;
|
||||
param->ends_[i] = in_shape_[i];
|
||||
param->in_shape_[i] = in_shape_[i];
|
||||
param->strides_[i] = 1;
|
||||
}
|
||||
TransferBuffer2Param(&transfer_buffer, param, in_shape_);
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue