add LiteKernel quant_type
This commit is contained in:
parent
0ade79cb84
commit
14b9f5797d
|
@ -96,6 +96,10 @@ class MS_API Kernel {
|
|||
///
|
||||
/// \return kernel's type.
|
||||
virtual schema::PrimitiveType type() const { return type_; }
|
||||
/// \brief obtain kernel's quant type.
|
||||
///
|
||||
/// \return kernel's quant type.
|
||||
virtual schema::QuantType quant_type() const { return quant_type_; }
|
||||
/// \brief obtain the primitive of kernel generated by flatbuffers.
|
||||
///
|
||||
/// \return the primitive of kernel generated by flatbuffers.
|
||||
|
@ -147,6 +151,7 @@ class MS_API Kernel {
|
|||
const schema::Primitive *primitive_ = nullptr;
|
||||
std::map<std::string, std::string> attrs_;
|
||||
const std::map<std::string, std::map<std::string, std::string>> *config_;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
|
||||
private:
|
||||
void Initialize();
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class ActivationTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ActivationTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ActivationTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -26,8 +26,9 @@ namespace mindspore::lite {
|
|||
class AllGatherTensorRT : public TensorRTOp {
|
||||
public:
|
||||
AllGatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~AllGatherTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class CastTensorRT : public TensorRTOp {
|
||||
public:
|
||||
CastTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~CastTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class ConcateTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ConcateTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ConcateTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class ConvolutionTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ConvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ConvolutionTensorRT() override;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class DeconvolutionTensorRT : public TensorRTOp {
|
||||
public:
|
||||
DeconvolutionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~DeconvolutionTensorRT() override;
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class ElementWiseTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ElementWiseTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ElementWiseTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class EqualTensorRT : public TensorRTOp {
|
||||
public:
|
||||
EqualTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~EqualTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ namespace mindspore::lite {
|
|||
class FullyConnectedTensorRT : public TensorRTOp {
|
||||
public:
|
||||
FullyConnectedTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~FullyConnectedTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class GatherTensorRT : public TensorRTOp {
|
||||
public:
|
||||
GatherTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~GatherTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -61,8 +61,9 @@ struct LstmWeights {
|
|||
class LSTMTensorRT : public TensorRTOp {
|
||||
public:
|
||||
LSTMTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~LSTMTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ namespace mindspore::lite {
|
|||
class MatMulTensorRT : public TensorRTOp {
|
||||
public:
|
||||
MatMulTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~MatMulTensorRT() override;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class PadTensorRT : public TensorRTOp {
|
||||
public:
|
||||
PadTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~PadTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class PoolTensorRT : public TensorRTOp {
|
||||
public:
|
||||
PoolTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~PoolTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ namespace mindspore::lite {
|
|||
class ReduceTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ReduceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ReduceTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -27,8 +27,9 @@ namespace mindspore::lite {
|
|||
class ReduceScatterTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ReduceScatterTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ReduceScatterTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ namespace mindspore::lite {
|
|||
class ResizeTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ResizeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ResizeTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -25,8 +25,9 @@ namespace mindspore::lite {
|
|||
class ScaleTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ScaleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ScaleTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class ShapeTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ShapeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ShapeTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class ShuffleTensorRT : public TensorRTOp {
|
||||
public:
|
||||
ShuffleTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~ShuffleTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -26,8 +26,9 @@ constexpr int AXIS_INDEX = 3;
|
|||
class SliceTensorRT : public TensorRTOp {
|
||||
public:
|
||||
SliceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~SliceTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ namespace mindspore::lite {
|
|||
class SoftMaxTensorRT : public TensorRTOp {
|
||||
public:
|
||||
SoftMaxTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~SoftMaxTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ std::vector<mindspore::MSTensor> &TensorRTOp::outputs() { return this->out_tenso
|
|||
|
||||
schema::PrimitiveType TensorRTOp::type() const { return this->type_; }
|
||||
|
||||
schema::QuantType TensorRTOp::GetQuantType() const { return this->quant_type_; }
|
||||
|
||||
void TensorRTOp::set_in_ops(const std::vector<TensorRTOp *> &in_ops) { this->in_ops_ = in_ops; }
|
||||
|
||||
void TensorRTOp::set_out_ops(const std::vector<TensorRTOp *> &out_ops) { this->out_ops_ = out_ops; }
|
||||
|
@ -74,10 +76,16 @@ int TensorRTOp::SetInt8DynamicRange() {
|
|||
MS_LOG(ERROR) << "input or output tensor empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) {
|
||||
MS_LOG(INFO) << op_name_ << " quant param is empty.";
|
||||
if (quant_type_ != schema::QuantType_QUANT_ALL) {
|
||||
MS_LOG(INFO) << "op " << op_name_ << " not quantized.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (in_tensors_[0].QuantParams().empty() || out_tensors_[0].QuantParams().empty()) {
|
||||
MS_LOG(WARNING) << op_name_ << " quant param is empty.";
|
||||
MS_LOG(WARNING) << "in_tensor quant param size: " << in_tensors_[0].QuantParams().size()
|
||||
<< " ,out_tensor quant param size: " << out_tensors_[0].QuantParams().size();
|
||||
}
|
||||
for (size_t i = 0; i < in_tensors_.size(); i++) {
|
||||
auto tensor = in_tensors_.at(i);
|
||||
if (!tensor.IsConst()) {
|
||||
|
|
|
@ -55,11 +55,12 @@ class TensorRTRuntime;
|
|||
class TensorRTOp {
|
||||
public:
|
||||
explicit TensorRTOp(const schema::Primitive *primitive, std::vector<mindspore::MSTensor> in_tensors,
|
||||
std::vector<mindspore::MSTensor> out_tensors, std::string name)
|
||||
std::vector<mindspore::MSTensor> out_tensors, std::string name, schema::QuantType quant_type)
|
||||
: op_primitive_(primitive),
|
||||
in_tensors_(std::move(in_tensors)),
|
||||
out_tensors_(std::move(out_tensors)),
|
||||
op_name_(std::move(name)) {
|
||||
op_name_(std::move(name)),
|
||||
quant_type_(quant_type) {
|
||||
if (primitive != nullptr) {
|
||||
this->type_ = primitive->value_type();
|
||||
}
|
||||
|
@ -94,6 +95,8 @@ class TensorRTOp {
|
|||
|
||||
schema::PrimitiveType type() const;
|
||||
|
||||
schema::QuantType GetQuantType() const;
|
||||
|
||||
void set_in_ops(const std::vector<TensorRTOp *> &in_ops);
|
||||
|
||||
void set_out_ops(const std::vector<TensorRTOp *> &out_ops);
|
||||
|
@ -136,6 +139,8 @@ class TensorRTOp {
|
|||
|
||||
schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
|
||||
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
|
||||
std::vector<BindingHelper> op_binding_tensor_;
|
||||
|
||||
TensorRTRuntime *runtime_{nullptr};
|
||||
|
@ -145,8 +150,9 @@ class TensorRTOp {
|
|||
|
||||
template <class T>
|
||||
TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name) {
|
||||
auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name);
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type) {
|
||||
auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name, quant_type);
|
||||
if (op == nullptr) {
|
||||
MS_LOG(WARNING) << "TensorRT is nullptr.";
|
||||
return nullptr;
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class TopKTensorRT : public TensorRTOp {
|
||||
public:
|
||||
TopKTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~TopKTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ namespace mindspore::lite {
|
|||
class UnaryTensorRT : public TensorRTOp {
|
||||
public:
|
||||
UnaryTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~UnaryTensorRT() override = default;
|
||||
|
||||
|
|
|
@ -242,7 +242,7 @@ TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schem
|
|||
auto name = kernel->name();
|
||||
auto node_type = primitive->value_type();
|
||||
if (op_func_lists_.find(node_type) != op_func_lists_.end()) {
|
||||
TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name);
|
||||
TensorRTOp *tensorrt_op = op_func_lists_[node_type](primitive, in_tensors, out_tensors, name, kernel->quant_type());
|
||||
if (tensorrt_op == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,8 @@
|
|||
namespace mindspore::lite {
|
||||
typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive,
|
||||
const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name);
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type);
|
||||
|
||||
class TensorRTDelegate : public Delegate {
|
||||
public:
|
||||
|
|
|
@ -144,34 +144,12 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
|
|||
}
|
||||
|
||||
bool TensorRTSubGraph::IsInt8Mode() {
|
||||
bool isInt8Mode = false;
|
||||
for (auto cur_op : all_ops_) {
|
||||
if (isInt8Mode) {
|
||||
break;
|
||||
}
|
||||
for (auto in_tensor : cur_op->inputs()) {
|
||||
if (cur_op->inputs().front().QuantParams().empty()) {
|
||||
continue;
|
||||
}
|
||||
auto quant_param = cur_op->inputs().front().QuantParams().front();
|
||||
if (quant_param.max > 0) {
|
||||
isInt8Mode = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto out_tensor : cur_op->outputs()) {
|
||||
if (cur_op->outputs().front().QuantParams().empty()) {
|
||||
continue;
|
||||
}
|
||||
auto quant_param = cur_op->outputs().front().QuantParams().front();
|
||||
if (quant_param.max > 0) {
|
||||
isInt8Mode = true;
|
||||
break;
|
||||
}
|
||||
if (cur_op->GetQuantType() == schema::QuantType_QUANT_ALL) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return isInt8Mode;
|
||||
return false;
|
||||
}
|
||||
|
||||
nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) {
|
||||
|
|
|
@ -95,6 +95,11 @@ class InnerKernel : public Kernel {
|
|||
: schema::PrimitiveType_NONE;
|
||||
}
|
||||
|
||||
schema::QuantType quant_type() const override {
|
||||
return (this->op_parameter_ != nullptr) ? schema::QuantType(this->op_parameter_->quant_type_)
|
||||
: schema::QuantType_QUANT_NONE;
|
||||
}
|
||||
|
||||
const std::vector<mindspore::MSTensor> &inputs() override {
|
||||
if (inputs_.empty()) {
|
||||
std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {
|
||||
|
|
Loading…
Reference in New Issue