diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h index 70eedd39a2a..1b5dc12e680 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h @@ -76,6 +76,11 @@ class BatchToSpacePlugin : public TensorRTPlugin { void serialize(void *buffer) const noexcept override; nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, + int nbOutputs) noexcept { + return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && + tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } private: int RunCudaBatchToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h index 145ba52bbb2..67e6786f337 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h @@ -76,6 +76,11 @@ class SpaceToBatchPlugin : public TensorRTPlugin { void serialize(void *buffer) const noexcept override; nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, + int nbOutputs) noexcept { + return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && + tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } private: int RunCudaSpaceToBatch(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,