reorganize strided_slice_infer code

This commit is contained in:
zhaodezan 2021-03-11 21:11:20 +08:00
parent 9f92c1fbb1
commit c060fcbafb
1 changed files with 191 additions and 143 deletions

View File

@ -21,6 +21,26 @@ const size_t kStridedSliceInputNum = 1;
const size_t kStridedSliceMultiInputNumMin = 3;
const size_t kStridedSliceMultiInputNumMax = 5;
typedef struct StridedSliceTransferBuffer {
int ndim_;
int begins_[MAX_SHAPE_SIZE];
int ends_[MAX_SHAPE_SIZE];
int strides_[MAX_SHAPE_SIZE];
int begins_mask_[MAX_SHAPE_SIZE];
int ends_mask_[MAX_SHAPE_SIZE];
int ellipsis_mask_[MAX_SHAPE_SIZE];
int new_axis_mask_[MAX_SHAPE_SIZE];
int shrink_axis_mask_[MAX_SHAPE_SIZE];
size_t begins_size_;
size_t ends_size_;
size_t strides_size_;
size_t ellipsis_mask_size_;
size_t new_axis_mask_size_;
size_t shrink_axis_mask_size_;
} StridedSliceTransferBuffer;
bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) {
for (size_t i = 1; i < inputs_size; ++i) {
if (inputs[i]->data_ == NULL) {
@ -128,10 +148,8 @@ int HandleAxesInputExist(const TensorC *const *inputs, int *ndim_, int *in_shape
return NNACL_OK;
}
// note: begin, end, stride length are equal, but may less than rank of input
int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
if (outputs_size != kStridedSliceOutputNum) {
return NNACL_PARAM_INVALID;
}
@ -142,6 +160,138 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) {
return NNACL_NULL_PTR;
}
return NNACL_OK;
}
void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param) {
for (int i = 0; i < transfer_buffer->ndim_; i++) {
transfer_buffer->begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i);
transfer_buffer->ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i);
transfer_buffer->ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i);
transfer_buffer->new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i);
transfer_buffer->shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i);
}
}
void ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_,
size_t *in_shape_size) {
for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) {
if (transfer_buffer->new_axis_mask_[i]) {
transfer_buffer->ndim_ += 1;
ShapeInsert(in_shape_, in_shape_size, i, 1);
transfer_buffer->begins_[i] = 0;
transfer_buffer->ends_[i] = 1;
transfer_buffer->strides_[i] = 1;
ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0);
ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape_[transfer_buffer->ndim_ - 1]);
ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1);
transfer_buffer->begins_mask_[i] = false;
transfer_buffer->ends_mask_[i] = false;
transfer_buffer->ellipsis_mask_[i] = false;
transfer_buffer->shrink_axis_mask_[i] = false;
}
}
}
void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) {
for (int i = 0; i < transfer_buffer->ndim_; i++) {
if (transfer_buffer->begins_mask_[i]) {
transfer_buffer->begins_[i] = 0;
}
}
}
void ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
for (int i = 0; i < transfer_buffer->ndim_; i++) {
if (transfer_buffer->ends_mask_[i]) {
transfer_buffer->ends_[i] = in_shape_[i];
}
}
}
void ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) {
if (transfer_buffer->ellipsis_mask_[i]) {
transfer_buffer->begins_[i] = 0;
transfer_buffer->ends_[i] = in_shape_[i];
break;
}
}
}
void TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) {
for (int i = 0; i < (int)(transfer_buffer->begins_size_); ++i) {
if (transfer_buffer->begins_[i] < 0) {
transfer_buffer->begins_[i] += in_shape_[i];
}
if (transfer_buffer->ends_[i] < 0) {
transfer_buffer->ends_[i] += in_shape_[i];
}
}
}
void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) {
int old_out_shape[MAX_SHAPE_SIZE];
size_t old_out_shape_size = 0;
ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size);
*output_shape_size = 0;
for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) {
if (transfer_buffer->shrink_axis_mask_[i]) {
transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1;
transfer_buffer->strides_[i] = 1;
} else {
ShapePush(output_shape, output_shape_size, old_out_shape[i]);
}
}
for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) {
ShapePush(output_shape, output_shape_size, old_out_shape[i]);
}
}
void TransferBuffer2Param(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_) {
for (int i = 0; i < transfer_buffer->ndim_; i++) {
param->begins_[i] = transfer_buffer->begins_[i];
param->ends_[i] = transfer_buffer->ends_[i];
param->in_shape_[i] = in_shape_[i];
param->strides_[i] = transfer_buffer->strides_[i];
}
for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) {
param->begins_[i] = 0;
param->ends_[i] = in_shape_[i];
param->in_shape_[i] = in_shape_[i];
param->strides_[i] = 1;
}
}
void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) {
transfer_buffer->begins_size_ = 0;
transfer_buffer->ends_size_ = 0;
transfer_buffer->strides_size_ = 0;
transfer_buffer->ellipsis_mask_size_ = 0;
transfer_buffer->new_axis_mask_size_ = 0;
transfer_buffer->shrink_axis_mask_size_ = 0;
}
void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) {
transfer_buffer->ellipsis_mask_size_ = transfer_buffer->ndim_;
transfer_buffer->new_axis_mask_size_ = transfer_buffer->ndim_;
transfer_buffer->shrink_axis_mask_size_ = transfer_buffer->ndim_;
transfer_buffer->begins_size_ = transfer_buffer->ndim_;
transfer_buffer->ends_size_ = transfer_buffer->ndim_;
transfer_buffer->strides_size_ = transfer_buffer->ndim_;
}
// note: begin, end, stride length are equal, but may less than rank of input
int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif
const TensorC *input = inputs[0];
@ -152,42 +302,29 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
}
int in_shape_[MAX_SHAPE_SIZE];
int begins_[MAX_SHAPE_SIZE];
int ends_[MAX_SHAPE_SIZE];
size_t in_shape_size_ = 0;
if (parameter->infer_flag_) {
ShapeSet(in_shape_, &in_shape_size_, input->shape_, input->shape_size_);
}
size_t begins_size_ = 0;
size_t ends_size_ = 0;
int strides_[MAX_SHAPE_SIZE];
size_t strides_size_ = 0;
int begins_mask_[MAX_SHAPE_SIZE];
int ends_mask_[MAX_SHAPE_SIZE];
int ellipsis_mask_[MAX_SHAPE_SIZE];
size_t ellipsis_mask_size_ = 0;
int new_axis_mask_[MAX_SHAPE_SIZE];
size_t new_axis_mask_size_ = 0;
int shrink_axis_mask_[MAX_SHAPE_SIZE];
size_t shrink_axis_mask_size_ = 0;
size_t in_shape_size = 0;
ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_);
StridedSliceTransferBuffer transfer_buffer;
InitStridedSliceTransferBuffer(&transfer_buffer);
StridedSliceParameter *param = (StridedSliceParameter *)parameter;
param->num_axes_ = in_shape_size_;
param->in_shape_length_ = in_shape_size_;
param->num_axes_ = in_shape_size;
param->in_shape_length_ = in_shape_size;
int ndim_ = 0;
transfer_buffer.ndim_ = 0;
if (inputs_size == kStridedSliceInputNum) {
ndim_ = (int)(param->num_axes_);
for (int i = 0; i < ndim_; i++) {
ShapePush(begins_, &begins_size_, param->begins_[i]);
ShapePush(ends_, &ends_size_, param->ends_[i]);
ShapePush(strides_, &strides_size_, param->strides_[i]);
transfer_buffer.ndim_ = (int)(param->num_axes_);
for (int i = 0; i < transfer_buffer.ndim_; i++) {
ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]);
ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]);
ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]);
}
}
if (!CheckInputs(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
if (inputs_size == 4) {
const TensorC *begin_tensor = inputs[1];
int *begin_data = (int *)(begin_tensor->data_);
@ -198,134 +335,45 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
if (begin_data == NULL || end_data == NULL || stride_data == NULL) {
return NNACL_ERR;
}
ndim_ = GetElementNum(begin_tensor);
for (int i = 0; i < ndim_; ++i) {
ShapePush(begins_, &begins_size_, begin_data[i]);
ShapePush(ends_, &ends_size_, end_data[i]);
ShapePush(strides_, &strides_size_, stride_data[i]);
transfer_buffer.ndim_ = GetElementNum(begin_tensor);
for (int i = 0; i < transfer_buffer.ndim_; ++i) {
ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, begin_data[i]);
ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, end_data[i]);
ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, stride_data[i]);
}
}
if (inputs_size == 5) {
int ret = HandleAxesInputExist(inputs, &ndim_, in_shape_, begins_, strides_, ends_);
int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape_, transfer_buffer.begins_,
transfer_buffer.strides_, transfer_buffer.ends_);
if (ret != NNACL_OK) {
return ret;
}
}
// set all mask to original input shape
ellipsis_mask_size_ = ndim_;
new_axis_mask_size_ = ndim_;
shrink_axis_mask_size_ = ndim_;
begins_size_ = ndim_;
ends_size_ = ndim_;
strides_size_ = ndim_;
// convert bit to vector
for (int i = 0; i < ndim_; i++) {
begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i);
ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i);
ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i);
new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i);
shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i);
}
// ApplyNewAxisMask();
for (size_t i = 0; i < new_axis_mask_size_; i++) {
if (new_axis_mask_[i]) {
ndim_ += 1;
ShapeInsert(in_shape_, &in_shape_size_, i, 1);
begins_[i] = 0;
ends_[i] = 1;
strides_[i] = 1;
ShapePush(begins_, &begins_size_, 0);
ShapePush(ends_, &ends_size_, in_shape_[ndim_ - 1]);
ShapePush(strides_, &strides_size_, 1);
begins_mask_[i] = false;
ends_mask_[i] = false;
ellipsis_mask_[i] = false;
shrink_axis_mask_[i] = false;
}
}
// ApplyBeginMask();
for (int i = 0; i < ndim_; i++) {
if (begins_mask_[i]) {
begins_[i] = 0;
}
}
// ApplyEndMask();
for (int i = 0; i < ndim_; i++) {
if (ends_mask_[i]) {
ends_[i] = in_shape_[i];
}
}
// ApplyEllipsisMask();
for (size_t i = 0; i < ellipsis_mask_size_; i++) {
if (ellipsis_mask_[i]) {
begins_[i] = 0;
ends_[i] = in_shape_[i];
break;
}
}
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
SetMaskSize(&transfer_buffer);
Bit2Vector(&transfer_buffer, param);
ApplyNewAxisMask(&transfer_buffer, param, in_shape_, &in_shape_size);
ApplyBeginMask(&transfer_buffer);
ApplyEndMask(&transfer_buffer, in_shape_);
ApplyEllipsisMask(&transfer_buffer, in_shape_);
int output_shape[MAX_SHAPE_SIZE];
size_t output_shape_size = 0;
ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size_);
// TransIndexToPositive();
for (int i = 0; i < (int)(begins_size_); ++i) {
if (begins_[i] < 0) {
begins_[i] += in_shape_[i];
}
if (ends_[i] < 0) {
ends_[i] += in_shape_[i];
}
}
for (int i = 0; i < ndim_; i++) {
if (strides_[i] == 0) {
ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size);
TransIndexToPositive(&transfer_buffer, in_shape_);
for (int i = 0; i < transfer_buffer.ndim_; i++) {
if (transfer_buffer.strides_[i] == 0) {
return NNACL_ERR;
}
output_shape[i] = (ends_[i] - begins_[i] + strides_[i] + (strides_[i] < 0 ? 1 : -1)) / strides_[i];
output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] +
(transfer_buffer.strides_[i] < 0 ? 1 : -1)) /
transfer_buffer.strides_[i];
}
// ApplyShrinkMask
int old_out_shape[MAX_SHAPE_SIZE];
size_t old_out_shape_size = 0;
ShapeSet(old_out_shape, &old_out_shape_size, output_shape, output_shape_size);
output_shape_size = 0;
for (size_t i = 0; i < shrink_axis_mask_size_; i++) {
if (shrink_axis_mask_[i]) {
ends_[i] = begins_[i] + 1;
strides_[i] = 1;
} else {
ShapePush(output_shape, &output_shape_size, old_out_shape[i]);
}
}
for (size_t i = shrink_axis_mask_size_; i < old_out_shape_size; i++) {
ShapePush(output_shape, &output_shape_size, old_out_shape[i]);
}
ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size);
SetShapeArray(outputs[0], output_shape, output_shape_size);
for (int i = 0; i < ndim_; i++) {
param->begins_[i] = begins_[i];
param->ends_[i] = ends_[i];
param->in_shape_[i] = in_shape_[i];
param->strides_[i] = strides_[i];
}
for (int i = ndim_; i < param->in_shape_length_; i++) {
param->begins_[i] = 0;
param->ends_[i] = in_shape_[i];
param->in_shape_[i] = in_shape_[i];
param->strides_[i] = 1;
}
TransferBuffer2Param(&transfer_buffer, param, in_shape_);
return NNACL_OK;
}