[MSLITE] Support rank 0 for expanddim and fix bug for fill

This commit is contained in:
张勇贤 2022-09-02 15:15:10 +08:00
parent 9a460222e4
commit 93217d2ef3
3 changed files with 17 additions and 20 deletions

View File

@ -49,14 +49,21 @@ int FillTensorRT::AddInnerOp(TensorRTContext *ctx) {
return RET_ERROR;
}
fill_layer->setInput(0, *input(ctx, 1).trt_tensor_);
auto alpha_tensor =
ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
nvinfer1::ITensor *alpha_tensor = nullptr;
if (in_tensors_[0].Data() == nullptr) {
alpha_tensor = input(ctx, 0).trt_tensor_;
} else {
alpha_tensor =
ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
}
fill_layer->setInput(1, *alpha_tensor);
int nbdims = input(ctx, 1).trt_tensor_->getDimensions().d[0];
zeros_ = std::vector<float>(nbdims, 0.f);
nvinfer1::Dims beta_dims{1, {nbdims}};
nvinfer1::Weights weights{ConvertDataType(DataType::kNumberTypeFloat32), &zeros_[0], nbdims};
auto beta_tensor = ctx->network()->addConstant(beta_dims, weights)->getOutput(0);
nvinfer1::ITensor *beta_tensor = nullptr;
if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
beta_tensor = ctx->ConvertTo1DTensor(std::vector<int>(nbdims, 0));
} else {
beta_tensor = ctx->ConvertTo1DTensor(std::vector<float>(nbdims, 0.f));
}
fill_layer->setInput(INPUT_SIZE2, *beta_tensor);
nvinfer1::ITensor *out_tensor = fill_layer->getOutput(0);

View File

@ -320,6 +320,10 @@ int ShuffleTensorRT::AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer) {
}
int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) {
if (!input(ctx_, 0).is_tensor_) {
shuffler_output_ = shuffler_input_;
return RET_OK;
}
if (in_tensors_[1].DataType() != DataType::kNumberTypeInt32) {
MS_LOG(WARNING) << op_name_ << " axis tensor data type is " << static_cast<int>(in_tensors_[1].DataType());
}

View File

@ -657,20 +657,6 @@ bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims
MS_LOG(ERROR) << "invalid resize input.";
return false;
}
if (input_hw_index_ == -1) {
// only NHWC format support HW resize, otherwise only support batchsize resize
for (int d = 0; d < construct_dims.nbDims; d++) {
if (d != input_batchsize_index_ && construct_dims.d[d] != resize_input_shape[d]) {
MS_LOG(ERROR) << "only support dynamic batch size resize input.";
return false;
}
}
} else if ((input_hw_index_ == 1 && construct_dims.d[DIMENSION_3D] != resize_input_shape[DIMENSION_3D]) ||
(input_hw_index_ == DIMENSION_2D && construct_dims.d[1] != resize_input_shape[1])) {
// input may be nhwc || nchw
MS_LOG(ERROR) << "don't support dynamic channel resize input.";
return false;
}
return true;
}