!49379 [MSLITE] limit batch2space and space2batch op type

Merge pull request !49379 from zhangyongxian/dev_zhangyongxian_sbbs
This commit is contained in:
i-robot 2023-02-25 09:00:11 +00:00 committed by Gitee
commit 7af0370e78
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 10 additions and 0 deletions

View File

@ -76,6 +76,11 @@ class BatchToSpacePlugin : public TensorRTPlugin {
void serialize(void *buffer) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT &&
tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
private:
int RunCudaBatchToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,

View File

@ -76,6 +76,11 @@ class SpaceToBatchPlugin : public TensorRTPlugin {
void serialize(void *buffer) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT &&
tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
private:
int RunCudaSpaceToBatch(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,