diff --git a/include/api/kernel.h b/include/api/kernel.h index 9b322711e06..940a4b35a78 100644 --- a/include/api/kernel.h +++ b/include/api/kernel.h @@ -96,6 +96,10 @@ class MS_API Kernel { /// /// \return kernel's type. virtual schema::PrimitiveType type() const { return type_; } + /// \brief obtain kernel's quant type. + /// + /// \return kernel's quant type. + virtual schema::QuantType quant_type() const { return quant_type_; } /// \brief obtain the primitive of kernel generated by flatbuffers. /// /// \return the primitive of kernel generated by flatbuffers. @@ -147,6 +151,7 @@ class MS_API Kernel { const schema::Primitive *primitive_ = nullptr; std::map attrs_; const std::map> *config_; + schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; private: void Initialize(); diff --git a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h index 10676b885a9..5efa156e3f8 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class ActivationTensorRT : public TensorRTOp { public: ActivationTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ActivationTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.h index ed38bfd66bc..dbdcb50e080 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/allgather_tensorrt.h @@ -26,8 +26,9 @@ namespace mindspore::lite { class AllGatherTensorRT : public TensorRTOp { public: AllGatherTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~AllGatherTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.h index c6784569d35..2a14c3cea35 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/cast_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class CastTensorRT : public TensorRTOp { public: CastTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~CastTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h index 6b2b3c5e13e..8ab16dcef6c 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class ConcateTensorRT : public TensorRTOp { public: ConcateTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ConcateTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h index b702a477191..660f3303475 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class ConvolutionTensorRT : public TensorRTOp { public: ConvolutionTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ConvolutionTensorRT() override; diff --git a/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.h index 09698d6a28b..dd4ec04beb9 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class DeconvolutionTensorRT : public TensorRTOp { public: DeconvolutionTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~DeconvolutionTensorRT() override; diff --git a/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h index e7bb4990c1b..2e197689016 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class ElementWiseTensorRT : public TensorRTOp { public: ElementWiseTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ElementWiseTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/equal_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/equal_tensorrt.h index 82637d2d302..2eafec3dc28 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/equal_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/equal_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class EqualTensorRT : public TensorRTOp { public: EqualTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~EqualTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/fullyconnected_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/fullyconnected_tensorrt.h index b13d82251d4..b9c0ed12c33 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/fullyconnected_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/fullyconnected_tensorrt.h @@ -25,8 +25,9 @@ namespace mindspore::lite { class FullyConnectedTensorRT : public TensorRTOp { public: FullyConnectedTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~FullyConnectedTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h index 8fece6eeaf7..7a154298313 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class GatherTensorRT : public TensorRTOp { public: GatherTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~GatherTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/lstm_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/lstm_tensorrt.h index 596f7bdd2b3..c311c9c3597 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/lstm_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/lstm_tensorrt.h @@ -61,8 +61,9 @@ struct LstmWeights { class LSTMTensorRT : public TensorRTOp { public: LSTMTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~LSTMTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h index b7a96c7d8dc..4553dad5a06 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h @@ -25,8 +25,9 @@ namespace mindspore::lite { class MatMulTensorRT : public TensorRTOp { public: MatMulTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~MatMulTensorRT() override; diff --git a/mindspore/lite/src/delegate/tensorrt/op/pad_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/pad_tensorrt.h index 6a7017cfe5b..651fab9f911 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/pad_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/pad_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class PadTensorRT : public TensorRTOp { public: PadTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~PadTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.h index 98fb66ed6a6..5661786f713 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class PoolTensorRT : public TensorRTOp { public: PoolTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~PoolTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h index 5563f9ab24b..e25b7ebe87f 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h @@ -25,8 +25,9 @@ namespace mindspore::lite { class ReduceTensorRT : public TensorRTOp { public: ReduceTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ReduceTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.h index 0121cd2d51d..4ace96dfff0 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/reducescatter_tensorrt.h @@ -27,8 +27,9 @@ namespace mindspore::lite { class ReduceScatterTensorRT : public TensorRTOp { public: ReduceScatterTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ReduceScatterTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/resize_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/resize_tensorrt.h index 804bd26a581..c44555ef461 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/resize_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/resize_tensorrt.h @@ -25,8 +25,9 @@ namespace mindspore::lite { class ResizeTensorRT : public TensorRTOp { public: ResizeTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ResizeTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h index 010825c7197..b3c04808fc2 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h @@ -25,8 +25,9 @@ namespace mindspore::lite { class ScaleTensorRT : public TensorRTOp { public: ScaleTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ScaleTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h index 90aa4c57ce9..cb4e7029104 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class ShapeTensorRT : public TensorRTOp { public: ShapeTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ShapeTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h index 62d86da9536..d2b874d4641 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class ShuffleTensorRT : public TensorRTOp { public: ShuffleTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~ShuffleTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/slice_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/slice_tensorrt.h index 6b55f0ca43b..a2979f2a44a 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/slice_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/slice_tensorrt.h @@ -26,8 +26,9 @@ constexpr int AXIS_INDEX = 3; class SliceTensorRT : public TensorRTOp { public: SliceTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~SliceTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h index e3b39ed6d57..7942b50aacc 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h @@ -23,8 +23,9 @@ namespace mindspore::lite { class SoftMaxTensorRT : public TensorRTOp { public: SoftMaxTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~SoftMaxTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc index ba57b61b6ec..1aed8d25191 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc @@ -37,6 +37,8 @@ std::vector &TensorRTOp::outputs() { return this->out_tenso schema::PrimitiveType TensorRTOp::type() const { return this->type_; } +schema::QuantType TensorRTOp::GetQuantType() const { return this->quant_type_; } + void TensorRTOp::set_in_ops(const std::vector &in_ops) { this->in_ops_ = in_ops; } void TensorRTOp::set_out_ops(const std::vector &out_ops) { this->out_ops_ = out_ops; } @@ -74,10 +76,16 @@ int TensorRTOp::SetInt8DynamicRange() { MS_LOG(ERROR) << "input or output tensor empty."; return RET_ERROR; } - if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) { - MS_LOG(INFO) << op_name_ << " quant param is empty."; + if (quant_type_ != schema::QuantType_QUANT_ALL) { + MS_LOG(INFO) << "op " << op_name_ << " not quantized."; return RET_OK; } + + if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) { + MS_LOG(WARNING) << op_name_ << " quant param is empty."; + MS_LOG(WARNING) << "in_tensor quant param size: " << in_tensors_[0].QuantParams().size() + << " ,out_tensor quant param size: " << out_tensors_[0].QuantParams().size(); + } for (size_t i = 0; i < in_tensors_.size(); i++) { auto tensor = in_tensors_.at(i); if (!tensor.IsConst()) { diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h index ad7ef31c2ca..8665641e04a 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h +++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h @@ -55,11 +55,12 @@ class TensorRTRuntime; class TensorRTOp { public: explicit TensorRTOp(const schema::Primitive *primitive, std::vector in_tensors, - std::vector out_tensors, std::string name) + std::vector out_tensors, std::string name, schema::QuantType quant_type) : op_primitive_(primitive), in_tensors_(std::move(in_tensors)), out_tensors_(std::move(out_tensors)), - op_name_(std::move(name)) { + op_name_(std::move(name)), + quant_type_(quant_type) { if (primitive != nullptr) { this->type_ = primitive->value_type(); } @@ -94,6 +95,8 @@ class TensorRTOp { schema::PrimitiveType type() const; + schema::QuantType GetQuantType() const; + void set_in_ops(const std::vector &in_ops); void set_out_ops(const std::vector &out_ops); @@ -136,6 +139,8 @@ class TensorRTOp { schema::PrimitiveType type_ = schema::PrimitiveType_NONE; + schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; + std::vector op_binding_tensor_; TensorRTRuntime *runtime_{nullptr}; @@ -145,8 +150,9 @@ class TensorRTOp { template TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) { - auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name); + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) { + auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name, quant_type); if (op == nullptr) { MS_LOG(WARNING) << "TensorRT is nullptr."; return nullptr; diff --git a/mindspore/lite/src/delegate/tensorrt/op/topk_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/topk_tensorrt.h index ec98c7bb191..d85c8cfde2d 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/topk_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/topk_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class TopKTensorRT : public TensorRTOp { public: TopKTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~TopKTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h index f877ee4977b..bf92dade401 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h +++ b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h @@ -24,8 +24,9 @@ namespace mindspore::lite { class UnaryTensorRT : public TensorRTOp { public: UnaryTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type) + : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {} ~UnaryTensorRT() override = default; diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc index 919854d71a3..347bb1a67e0 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc @@ -242,7 +242,7 @@ TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schem auto name = kernel->name(); auto node_type = primitive->value_type(); if (op_func_lists_.find(node_type) != op_func_lists_.end()) { - TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name); + TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name, kernel->quant_type()); if (tensorrt_op == nullptr) { return nullptr; } diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h index ae3180e3ea4..c7b54388972 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h @@ -31,7 +31,8 @@ namespace mindspore::lite { typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name); + const std::vector &out_tensors, const std::string &name, + const schema::QuantType &quant_type); class TensorRTDelegate : public Delegate { public: diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc index 9494f753dd3..ae47659e762 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc @@ -144,34 +144,12 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) { } bool TensorRTSubGraph::IsInt8Mode() { - bool isInt8Mode = false; for (auto cur_op : all_ops_) { - if (isInt8Mode) { - break; - } - for (auto in_tensor : cur_op->inputs()) { - if (cur_op->inputs().front().QuantParams().empty()) { - continue; - } - auto quant_param = cur_op->inputs().front().QuantParams().front(); - if (quant_param.max > 0) { - isInt8Mode = true; - break; - } - } - - for (auto out_tensor : cur_op->outputs()) { - if (cur_op->outputs().front().QuantParams().empty()) { - continue; - } - auto quant_param = cur_op->outputs().front().QuantParams().front(); - if (quant_param.max > 0) { - isInt8Mode = true; - break; - } + if (cur_op->GetQuantType() == schema::QuantType_QUANT_ALL) { + return true; } } - return isInt8Mode; + return false; } nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) { diff --git a/mindspore/lite/src/inner_kernel.h b/mindspore/lite/src/inner_kernel.h index 0349c8d0066..b5076176f6f 100644 --- a/mindspore/lite/src/inner_kernel.h +++ b/mindspore/lite/src/inner_kernel.h @@ -95,6 +95,11 @@ class InnerKernel : public Kernel { : schema::PrimitiveType_NONE; } + schema::QuantType quant_type() const override { + return (this->op_parameter_ != nullptr) ? schema::QuantType(this->op_parameter_->quant_type_) + : schema::QuantType_QUANT_NONE; + } + const std::vector &inputs() override { if (inputs_.empty()) { std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {