add LiteKernel quant_type

This commit is contained in:
albert-yan 2022-02-23 17:34:36 +08:00
parent 0ade79cb84
commit 14b9f5797d
31 changed files with 108 additions and 81 deletions

View File

@ -96,6 +96,10 @@ class MS_API Kernel {
/// ///
/// \return kernel's type. /// \return kernel's type.
virtual schema::PrimitiveType type() const { return type_; } virtual schema::PrimitiveType type() const { return type_; }
/// \brief Get the quantization type recorded for this kernel.
///
/// \return the current value of the quant_type_ member.
virtual schema::QuantType quant_type() const {
  return this->quant_type_;
}
/// \brief obtain the primitive of kernel generated by flatbuffers. /// \brief obtain the primitive of kernel generated by flatbuffers.
/// ///
/// \return the primitive of kernel generated by flatbuffers. /// \return the primitive of kernel generated by flatbuffers.
@ -147,6 +151,7 @@ class MS_API Kernel {
const schema::Primitive *primitive_ = nullptr; const schema::Primitive *primitive_ = nullptr;
std::map<std::string, std::string> attrs_; std::map<std::string, std::string> attrs_;
const std::map<std::string, std::map<std::string, std::string>> *config_; const std::map<std::string, std::map<std::string, std::string>> *config_;
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
private: private:
void Initialize(); void Initialize();

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class ActivationTensorRT : public TensorRTOp { class ActivationTensorRT : public TensorRTOp {
public: public:
ActivationTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ActivationTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ActivationTensorRT() override = default; ~ActivationTensorRT() override = default;

View File

@ -26,8 +26,9 @@ namespace mindspore::lite {
class AllGatherTensorRT : public TensorRTOp { class AllGatherTensorRT : public TensorRTOp {
public: public:
AllGatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, AllGatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~AllGatherTensorRT() override = default; ~AllGatherTensorRT() override = default;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class CastTensorRT : public TensorRTOp { class CastTensorRT : public TensorRTOp {
public: public:
CastTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, CastTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~CastTensorRT() override = default; ~CastTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class ConcateTensorRT : public TensorRTOp { class ConcateTensorRT : public TensorRTOp {
public: public:
ConcateTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ConcateTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ConcateTensorRT() override = default; ~ConcateTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class ConvolutionTensorRT : public TensorRTOp { class ConvolutionTensorRT : public TensorRTOp {
public: public:
ConvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ConvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ConvolutionTensorRT() override; ~ConvolutionTensorRT() override;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class DeconvolutionTensorRT : public TensorRTOp { class DeconvolutionTensorRT : public TensorRTOp {
public: public:
DeconvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, DeconvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~DeconvolutionTensorRT() override; ~DeconvolutionTensorRT() override;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class ElementWiseTensorRT : public TensorRTOp { class ElementWiseTensorRT : public TensorRTOp {
public: public:
ElementWiseTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ElementWiseTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ElementWiseTensorRT() override = default; ~ElementWiseTensorRT() override = default;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class EqualTensorRT : public TensorRTOp { class EqualTensorRT : public TensorRTOp {
public: public:
EqualTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, EqualTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~EqualTensorRT() override = default; ~EqualTensorRT() override = default;

View File

@ -25,8 +25,9 @@ namespace mindspore::lite {
class FullyConnectedTensorRT : public TensorRTOp { class FullyConnectedTensorRT : public TensorRTOp {
public: public:
FullyConnectedTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, FullyConnectedTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~FullyConnectedTensorRT() override = default; ~FullyConnectedTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class GatherTensorRT : public TensorRTOp { class GatherTensorRT : public TensorRTOp {
public: public:
GatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, GatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~GatherTensorRT() override = default; ~GatherTensorRT() override = default;

View File

@ -61,8 +61,9 @@ struct LstmWeights {
class LSTMTensorRT : public TensorRTOp { class LSTMTensorRT : public TensorRTOp {
public: public:
LSTMTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, LSTMTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~LSTMTensorRT() override = default; ~LSTMTensorRT() override = default;

View File

@ -25,8 +25,9 @@ namespace mindspore::lite {
class MatMulTensorRT : public TensorRTOp { class MatMulTensorRT : public TensorRTOp {
public: public:
MatMulTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, MatMulTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~MatMulTensorRT() override; ~MatMulTensorRT() override;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class PadTensorRT : public TensorRTOp { class PadTensorRT : public TensorRTOp {
public: public:
PadTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, PadTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~PadTensorRT() override = default; ~PadTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class PoolTensorRT : public TensorRTOp { class PoolTensorRT : public TensorRTOp {
public: public:
PoolTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, PoolTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~PoolTensorRT() override = default; ~PoolTensorRT() override = default;

View File

@ -25,8 +25,9 @@ namespace mindspore::lite {
class ReduceTensorRT : public TensorRTOp { class ReduceTensorRT : public TensorRTOp {
public: public:
ReduceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ReduceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ReduceTensorRT() override = default; ~ReduceTensorRT() override = default;

View File

@ -27,8 +27,9 @@ namespace mindspore::lite {
class ReduceScatterTensorRT : public TensorRTOp { class ReduceScatterTensorRT : public TensorRTOp {
public: public:
ReduceScatterTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ReduceScatterTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ReduceScatterTensorRT() override = default; ~ReduceScatterTensorRT() override = default;

View File

@ -25,8 +25,9 @@ namespace mindspore::lite {
class ResizeTensorRT : public TensorRTOp { class ResizeTensorRT : public TensorRTOp {
public: public:
ResizeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ResizeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ResizeTensorRT() override = default; ~ResizeTensorRT() override = default;

View File

@ -25,8 +25,9 @@ namespace mindspore::lite {
class ScaleTensorRT : public TensorRTOp { class ScaleTensorRT : public TensorRTOp {
public: public:
ScaleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ScaleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ScaleTensorRT() override = default; ~ScaleTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class ShapeTensorRT : public TensorRTOp { class ShapeTensorRT : public TensorRTOp {
public: public:
ShapeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ShapeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ShapeTensorRT() override = default; ~ShapeTensorRT() override = default;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class ShuffleTensorRT : public TensorRTOp { class ShuffleTensorRT : public TensorRTOp {
public: public:
ShuffleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, ShuffleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~ShuffleTensorRT() override = default; ~ShuffleTensorRT() override = default;

View File

@ -26,8 +26,9 @@ constexpr int AXIS_INDEX = 3;
class SliceTensorRT : public TensorRTOp { class SliceTensorRT : public TensorRTOp {
public: public:
SliceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, SliceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~SliceTensorRT() override = default; ~SliceTensorRT() override = default;

View File

@ -23,8 +23,9 @@ namespace mindspore::lite {
class SoftMaxTensorRT : public TensorRTOp { class SoftMaxTensorRT : public TensorRTOp {
public: public:
SoftMaxTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, SoftMaxTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~SoftMaxTensorRT() override = default; ~SoftMaxTensorRT() override = default;

View File

@ -37,6 +37,8 @@ std::vector<mindspore::MSTensor> &TensorRTOp::outputs() { return this->out_tenso
schema::PrimitiveType TensorRTOp::type() const { return this->type_; } schema::PrimitiveType TensorRTOp::type() const { return this->type_; }
// Accessor for the quantization type stored in quant_type_.
schema::QuantType TensorRTOp::GetQuantType() const {
  return quant_type_;
}
void TensorRTOp::set_in_ops(const std::vector<TensorRTOp *> &in_ops) { this->in_ops_ = in_ops; } void TensorRTOp::set_in_ops(const std::vector<TensorRTOp *> &in_ops) { this->in_ops_ = in_ops; }
void TensorRTOp::set_out_ops(const std::vector<TensorRTOp *> &out_ops) { this->out_ops_ = out_ops; } void TensorRTOp::set_out_ops(const std::vector<TensorRTOp *> &out_ops) { this->out_ops_ = out_ops; }
@ -74,10 +76,16 @@ int TensorRTOp::SetInt8DynamicRange() {
MS_LOG(ERROR) << "input or output tensor empty."; MS_LOG(ERROR) << "input or output tensor empty.";
return RET_ERROR; return RET_ERROR;
} }
if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) { if (quant_type_ != schema::QuantType_QUANT_ALL) {
MS_LOG(INFO) << op_name_ << " quant param is empty."; MS_LOG(INFO) << "op " << op_name_ << " not quantized.";
return RET_OK; return RET_OK;
} }
if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) {
MS_LOG(WARNING) << op_name_ << " quant param is empty.";
MS_LOG(WARNING) << "in_tensor quant param size: " << in_tensors_[0].QuantParams().size()
<< " ,out_tensor quant param size: " << out_tensors_[0].QuantParams().size();
}
for (size_t i = 0; i < in_tensors_.size(); i++) { for (size_t i = 0; i < in_tensors_.size(); i++) {
auto tensor = in_tensors_.at(i); auto tensor = in_tensors_.at(i);
if (!tensor.IsConst()) { if (!tensor.IsConst()) {

View File

@ -55,11 +55,12 @@ class TensorRTRuntime;
class TensorRTOp { class TensorRTOp {
public: public:
explicit TensorRTOp(const schema::Primitive *primitive, std::vector<mindspore::MSTensor> in_tensors, explicit TensorRTOp(const schema::Primitive *primitive, std::vector<mindspore::MSTensor> in_tensors,
std::vector<mindspore::MSTensor> out_tensors, std::string name) std::vector<mindspore::MSTensor> out_tensors, std::string name, schema::QuantType quant_type)
: op_primitive_(primitive), : op_primitive_(primitive),
in_tensors_(std::move(in_tensors)), in_tensors_(std::move(in_tensors)),
out_tensors_(std::move(out_tensors)), out_tensors_(std::move(out_tensors)),
op_name_(std::move(name)) { op_name_(std::move(name)),
quant_type_(quant_type) {
if (primitive != nullptr) { if (primitive != nullptr) {
this->type_ = primitive->value_type(); this->type_ = primitive->value_type();
} }
@ -94,6 +95,8 @@ class TensorRTOp {
schema::PrimitiveType type() const; schema::PrimitiveType type() const;
schema::QuantType GetQuantType() const;
void set_in_ops(const std::vector<TensorRTOp *> &in_ops); void set_in_ops(const std::vector<TensorRTOp *> &in_ops);
void set_out_ops(const std::vector<TensorRTOp *> &out_ops); void set_out_ops(const std::vector<TensorRTOp *> &out_ops);
@ -136,6 +139,8 @@ class TensorRTOp {
schema::PrimitiveType type_ = schema::PrimitiveType_NONE; schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
std::vector<BindingHelper> op_binding_tensor_; std::vector<BindingHelper> op_binding_tensor_;
TensorRTRuntime *runtime_{nullptr}; TensorRTRuntime *runtime_{nullptr};
@ -145,8 +150,9 @@ class TensorRTOp {
template <class T> template <class T>
TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) { const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name); const schema::QuantType &quant_type) {
auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name, quant_type);
if (op == nullptr) { if (op == nullptr) {
MS_LOG(WARNING) << "TensorRT is nullptr."; MS_LOG(WARNING) << "TensorRT is nullptr.";
return nullptr; return nullptr;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class TopKTensorRT : public TensorRTOp { class TopKTensorRT : public TensorRTOp {
public: public:
TopKTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, TopKTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~TopKTensorRT() override = default; ~TopKTensorRT() override = default;

View File

@ -24,8 +24,9 @@ namespace mindspore::lite {
class UnaryTensorRT : public TensorRTOp { class UnaryTensorRT : public TensorRTOp {
public: public:
UnaryTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors, UnaryTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
: TensorRTOp(primitive, in_tensors, out_tensors, name) {} const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~UnaryTensorRT() override = default; ~UnaryTensorRT() override = default;

View File

@ -242,7 +242,7 @@ TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schem
auto name = kernel->name(); auto name = kernel->name();
auto node_type = primitive->value_type(); auto node_type = primitive->value_type();
if (op_func_lists_.find(node_type) != op_func_lists_.end()) { if (op_func_lists_.find(node_type) != op_func_lists_.end()) {
TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name); TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name, kernel->quant_type());
if (tensorrt_op == nullptr) { if (tensorrt_op == nullptr) {
return nullptr; return nullptr;
} }

View File

@ -31,7 +31,8 @@
namespace mindspore::lite { namespace mindspore::lite {
typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive, typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name); const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type);
class TensorRTDelegate : public Delegate { class TensorRTDelegate : public Delegate {
public: public:

View File

@ -144,34 +144,12 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
} }
bool TensorRTSubGraph::IsInt8Mode() { bool TensorRTSubGraph::IsInt8Mode() {
bool isInt8Mode = false;
for (auto cur_op : all_ops_) { for (auto cur_op : all_ops_) {
if (isInt8Mode) { if (cur_op->GetQuantType() == schema::QuantType_QUANT_ALL) {
break; return true;
}
for (auto in_tensor : cur_op->inputs()) {
if (cur_op->inputs().front().QuantParams().empty()) {
continue;
}
auto quant_param = cur_op->inputs().front().QuantParams().front();
if (quant_param.max > 0) {
isInt8Mode = true;
break;
}
}
for (auto out_tensor : cur_op->outputs()) {
if (cur_op->outputs().front().QuantParams().empty()) {
continue;
}
auto quant_param = cur_op->outputs().front().QuantParams().front();
if (quant_param.max > 0) {
isInt8Mode = true;
break;
}
} }
} }
return isInt8Mode; return false;
} }
nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) { nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) {

View File

@ -95,6 +95,11 @@ class InnerKernel : public Kernel {
: schema::PrimitiveType_NONE; : schema::PrimitiveType_NONE;
} }
// Report the quantization type carried by op_parameter_; when no parameter
// is attached, fall back to QuantType_QUANT_NONE.
schema::QuantType quant_type() const override {
  if (this->op_parameter_ == nullptr) {
    return schema::QuantType_QUANT_NONE;
  }
  return schema::QuantType(this->op_parameter_->quant_type_);
}
const std::vector<mindspore::MSTensor> &inputs() override { const std::vector<mindspore::MSTensor> &inputs() override {
if (inputs_.empty()) { if (inputs_.empty()) {
std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) { std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {