forked from mindspore-Ecosystem/mindspore
[MSLITE] Support rank 0 for expanddim and fix bug for fill
parent 9a460222e4
commit 93217d2ef3
@@ -49,14 +49,21 @@ int FillTensorRT::AddInnerOp(TensorRTContext *ctx) {
     return RET_ERROR;
   }
   fill_layer->setInput(0, *input(ctx, 1).trt_tensor_);
-  auto alpha_tensor =
-    ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
+  nvinfer1::ITensor *alpha_tensor = nullptr;
+  if (in_tensors_[0].Data() == nullptr) {
+    alpha_tensor = input(ctx, 0).trt_tensor_;
+  } else {
+    alpha_tensor =
+      ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
+  }
   fill_layer->setInput(1, *alpha_tensor);
   int nbdims = input(ctx, 1).trt_tensor_->getDimensions().d[0];
-  zeros_ = std::vector<float>(nbdims, 0.f);
-  nvinfer1::Dims beta_dims{1, {nbdims}};
-  nvinfer1::Weights weights{ConvertDataType(DataType::kNumberTypeFloat32), &zeros_[0], nbdims};
-  auto beta_tensor = ctx->network()->addConstant(beta_dims, weights)->getOutput(0);
+  nvinfer1::ITensor *beta_tensor = nullptr;
+  if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
+    beta_tensor = ctx->ConvertTo1DTensor(std::vector<int>(nbdims, 0));
+  } else {
+    beta_tensor = ctx->ConvertTo1DTensor(std::vector<float>(nbdims, 0.f));
+  }
   fill_layer->setInput(INPUT_SIZE2, *beta_tensor);
 
   nvinfer1::ITensor *out_tensor = fill_layer->getOutput(0);
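The Fill fix has two parts: the fill value (alpha) is taken directly from the network input when the first in-tensor carries no constant data, and the zero delta (beta) is now built via ctx->ConvertTo1DTensor in the same data type as the fill value instead of always as float32 constant weights. Below is a minimal standalone sketch, not the MindSpore Lite code, of the TensorRT fill-layer wiring this relies on: input 0 is the runtime output shape, input 1 the start scalar, input 2 the per-dimension delta. BuildConstantFill, its arguments, and the lifetime handling are assumptions for illustration only.

// Illustrative sketch of TensorRT's dynamic Fill wiring (not MindSpore code).
// Assumes TensorRT 8.x headers; BuildConstantFill is a hypothetical helper.
#include <NvInfer.h>

#include <vector>

// Builds linspace(start = alpha, delta = 0) over a runtime shape, i.e. a constant
// fill. `zeros` must stay alive until the engine is built.
inline nvinfer1::ITensor *BuildConstantFill(nvinfer1::INetworkDefinition *net,
                                            nvinfer1::ITensor *shape,  // 1D kINT32 shape tensor
                                            nvinfer1::ITensor *alpha,  // rank-0 fill value
                                            const std::vector<float> &zeros) {
  const int nbdims = static_cast<int>(zeros.size());
  nvinfer1::Dims dummy{1, {1}};  // placeholder; overridden by setInput(0, shape)
  auto *fill = net->addFill(dummy, nvinfer1::FillOperation::kLINSPACE);
  fill->setInput(0, *shape);  // output shape, known only at runtime
  fill->setInput(1, *alpha);  // start value (the patch keeps beta in the same type as this)
  nvinfer1::Weights w{nvinfer1::DataType::kFLOAT, zeros.data(), nbdims};
  auto *beta = net->addConstant(nvinfer1::Dims{1, {nbdims}}, w)->getOutput(0);
  fill->setInput(2, *beta);  // per-dimension delta of zeros
  return fill->getOutput(0);
}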
|
@@ -320,6 +320,10 @@ int ShuffleTensorRT::AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer) {
 }
 
 int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) {
+  if (!input(ctx_, 0).is_tensor_) {
+    shuffler_output_ = shuffler_input_;
+    return RET_OK;
+  }
   if (in_tensors_[1].DataType() != DataType::kNumberTypeInt32) {
     MS_LOG(WARNING) << op_name_ << " axis tensor data type is " << static_cast<int>(in_tensors_[1].DataType());
   }
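The ExpandDims change adds a rank-0 guard: when the first input is not materialized as a TensorRT tensor (presumably a scalar handled as a constant elsewhere in the delegate), the shuffle becomes a pass-through and the op returns RET_OK instead of trying to reshape an empty dimension list. For reference, the normal rank-expanding path goes through IShuffleLayer reshape dimensions; the sketch below is illustrative only, and ExpandDim is a hypothetical helper, not this repository's function.

// Illustrative sketch: insert a size-1 axis with an IShuffleLayer (static shapes).
// Not the MindSpore Lite implementation; names are assumptions.
#include <NvInfer.h>

inline nvinfer1::ITensor *ExpandDim(nvinfer1::INetworkDefinition *net, nvinfer1::ITensor *in, int axis) {
  const nvinfer1::Dims in_dims = in->getDimensions();
  nvinfer1::Dims out_dims{};
  out_dims.nbDims = in_dims.nbDims + 1;
  for (int i = 0, j = 0; i < out_dims.nbDims; ++i) {
    out_dims.d[i] = (i == axis) ? 1 : in_dims.d[j++];  // size-1 slot at `axis`
  }
  auto *shuffle = net->addShuffle(*in);
  shuffle->setReshapeDimensions(out_dims);  // for a rank-0 input this yields a one-element 1-D tensor
  return shuffle->getOutput(0);
}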
@@ -657,20 +657,6 @@ bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims
     MS_LOG(ERROR) << "invalid resize input.";
     return false;
   }
-  if (input_hw_index_ == -1) {
-    // only NHWC format support HW resize, otherwise only support batchsize resize
-    for (int d = 0; d < construct_dims.nbDims; d++) {
-      if (d != input_batchsize_index_ && construct_dims.d[d] != resize_input_shape[d]) {
-        MS_LOG(ERROR) << "only support dynamic batch size resize input.";
-        return false;
-      }
-    }
-  } else if ((input_hw_index_ == 1 && construct_dims.d[DIMENSION_3D] != resize_input_shape[DIMENSION_3D]) ||
-             (input_hw_index_ == DIMENSION_2D && construct_dims.d[1] != resize_input_shape[1])) {
-    // input may be nhwc || nchw
-    MS_LOG(ERROR) << "don't support dynamic channel resize input.";
-    return false;
-  }
   return true;
 }
 
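The last hunk drops the guard that rejected any resize whose non-batch (or non-HW) dimensions differed from the built shape, so ValidInputResizeDims now keeps only the rank check that precedes it. In TensorRT terms, resizing an input dimension at runtime is only legal if the engine was built with an optimization profile covering that range; the snippet below is a generic usage sketch of such a profile, with a made-up tensor name and shapes, and is not code from this commit or this subgraph.

// Generic TensorRT sketch (not from this commit): register an optimization
// profile that lets every dimension of the tensor named "input" vary at runtime.
#include <NvInfer.h>

inline void AddFullyDynamicProfile(nvinfer1::IBuilder *builder, nvinfer1::IBuilderConfig *config) {
  nvinfer1::IOptimizationProfile *profile = builder->createOptimizationProfile();
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{1, 3, 32, 32});
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{8, 3, 224, 224});
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{32, 3, 512, 512});
  config->addOptimizationProfile(profile);
}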