forked from mindspore-Ecosystem/mindspore
!49379 [MSLITE] limit batch2space and space2batch op type
Merge pull request !49379 from zhangyongxian/dev_zhangyongxian_sbbs
This commit is contained in:
commit
7af0370e78
|
@ -76,6 +76,11 @@ class BatchToSpacePlugin : public TensorRTPlugin {
|
|||
void serialize(void *buffer) const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
|
||||
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
|
||||
int nbOutputs) noexcept {
|
||||
return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
private:
|
||||
int RunCudaBatchToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,
|
||||
|
|
|
@ -76,6 +76,11 @@ class SpaceToBatchPlugin : public TensorRTPlugin {
|
|||
void serialize(void *buffer) const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
|
||||
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
|
||||
int nbOutputs) noexcept {
|
||||
return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
private:
|
||||
int RunCudaSpaceToBatch(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,
|
||||
|
|
Loading…
Reference in New Issue