!42863 [LITE] fix onehot output datatype and add axis attribute

Merge pull request !42863 from WangWenzhe/1.8_trt_ops
This commit is contained in:
i-robot 2022-10-10 06:57:52 +00:00 committed by Gitee
commit a8cfa7e3ad
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 85 additions and 29 deletions

View File

@ -51,12 +51,12 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) {
input(ctx, ON_VALUE_INDEX).trt_tensor_, input(ctx, OFF_VALUE_INDEX).trt_tensor_};
ITensorHelper indice_helper = input(ctx, 0);
if (indice_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) {
inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in");
inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in");
}
ITensorHelper depth_helper = input(ctx, DEPTH_INDEX);
if (depth_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) {
inputTensors[DEPTH_INDEX] =
TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in");
TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in");
}
mindspore::MSTensor &depth_tensor = in_tensors_[DEPTH_INDEX];
if (depth_tensor.Data() == nullptr) {
@ -65,7 +65,9 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) {
}
const int *depth_ptr = reinterpret_cast<const int *>(depth_tensor.Data().get());
int depth = *depth_ptr;
auto plugin = std::make_shared<OnehotPlugin>(op_name_, depth, op_primitive_->value_type());
auto onehot_op = op_primitive_->value_as_OneHot();
int axis = onehot_op->axis();
auto plugin = std::make_shared<OnehotPlugin>(op_name_, axis, depth, op_primitive_->value_type());
if (plugin == nullptr) {
MS_LOG(ERROR) << "create OnehotPlugin failed for " << op_name_;
return RET_ERROR;
@ -98,12 +100,19 @@ int OnehotPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvi
int OnehotPlugin::RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs,
void *const *outputs, cudaStream_t stream) {
int left_dims = 1;
int right_dims = 1;
for (int i = 0; i != inputDesc[0].dims.nbDims; ++i) {
feature_dims_ *= inputDesc[0].dims.d[i];
if (axis_ == -1 || i < IntToSize(axis_)) {
left_dims *= inputDesc[0].dims.d[i];
}
if (axis_ != -1 && i >= IntToSize(axis_)) {
right_dims *= inputDesc[0].dims.d[i];
}
}
if (inputDesc[0].type == nvinfer1::DataType::kINT32 && inputDesc[ON_VALUE_INDEX].type == nvinfer1::DataType::kFLOAT) {
OneHot<float, int>(static_cast<const int *>(inputs[0]), depth_, static_cast<const float *>(inputs[ON_VALUE_INDEX]),
static_cast<const float *>(inputs[OFF_VALUE_INDEX]), batch_dims_, feature_dims_,
static_cast<const float *>(inputs[OFF_VALUE_INDEX]), left_dims, right_dims,
static_cast<float *>(outputs[0]), device_id_, stream);
} else {
MS_LOG(ERROR) << "invalid onehot type: " << static_cast<int>(primitive_type_);
@ -117,16 +126,22 @@ nvinfer1::DimsExprs OnehotPlugin::getOutputDimensions(int32_t index, const nvinf
nvinfer1::DimsExprs dims;
dims.nbDims = inputs[0].nbDims + 1;
auto indice_dims = inputs[0].nbDims;
if (indice_dims == 1) {
dims.d[0] = inputs[0].d[0];
auto depth = exprBuilder.constant(depth_);
dims.d[1] = depth;
} else {
if (axis_ == -1) {
for (int i = 0; i != inputs[0].nbDims; ++i) {
dims.d[i] = inputs[0].d[i];
}
auto depth = exprBuilder.constant(depth_);
dims.d[inputs[0].nbDims] = depth;
dims.d[dims.nbDims - 1] = depth;
} else {
for (int i = 0; i != indice_dims; ++i) {
if (i >= axis_) {
dims.d[i + 1] = inputs[0].d[i];
} else {
dims.d[i] = inputs[0].d[i];
}
}
auto depth = exprBuilder.constant(depth_);
dims.d[axis_] = depth;
}
return dims;
}
@ -137,11 +152,15 @@ nvinfer1::IPluginV2DynamicExt *OnehotPlugin::clone() const noexcept {
return plugin;
}
size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 3 * sizeof(int); }
size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 2 * sizeof(int); }
// Reports the plugin's output element type: the one-hot result carries the
// data type of the on_value input (inputTypes[ON_VALUE_INDEX]), not the int32
// type of the indices tensor at input 0.
nvinfer1::DataType OnehotPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept {
return inputTypes[ON_VALUE_INDEX];
}
void OnehotPlugin::serialize(void *buffer) const noexcept {
SerializeValue(&buffer, &batch_dims_, sizeof(int));
SerializeValue(&buffer, &feature_dims_, sizeof(int));
SerializeValue(&buffer, &axis_, sizeof(int));
SerializeValue(&buffer, &depth_, sizeof(int));
SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType));
}

View File

@ -41,20 +41,23 @@ class OnehotTensorRT : public TensorRTOp {
constexpr char *ONEHOT_PLUGIN_NAME{"OnehotPlugin"};
class OnehotPlugin : public TensorRTPlugin {
public:
OnehotPlugin(const std::string name, int depth, schema::PrimitiveType primitive_type)
: TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)), depth_(depth), primitive_type_(primitive_type) {}
OnehotPlugin(const std::string name, int axis, int depth, schema::PrimitiveType primitive_type)
: TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)),
axis_(axis),
depth_(depth),
primitive_type_(primitive_type) {}
OnehotPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
depth_ = static_cast<const int *>(fields[0].data)[0];
primitive_type_ = static_cast<const schema::PrimitiveType *>(fields[1].data)[0];
axis_ = static_cast<const int *>(fields[0].data)[0];
depth_ = static_cast<const int *>(fields[1].data)[0];
primitive_type_ = static_cast<const schema::PrimitiveType *>(fields[2].data)[0];
}
OnehotPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &batch_dims_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &feature_dims_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &depth_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType));
}
@ -65,6 +68,8 @@ class OnehotPlugin : public TensorRTPlugin {
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept override;
void serialize(void *buffer) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
@ -72,10 +77,8 @@ class OnehotPlugin : public TensorRTPlugin {
private:
int RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,
cudaStream_t stream);
int batch_dims_{1};
int feature_dims_{1};
int depth_{1};
int axis_{-1};
schema::PrimitiveType primitive_type_;
};
class OnehotPluginCreater : public TensorRTPluginCreater<OnehotPlugin> {

View File

@ -40,6 +40,22 @@ int WhereTensorRT::IsSupport(const schema::Primitive *primitive, const std::vect
return RET_OK;
}
// Prepends size-1 dimensions to input_tensor until its rank matches the rank
// of the condition input (input 0), so a subsequent element-wise broadcast
// against the condition's shape is valid.
//
// Params:
//   ctx          - TensorRT build context owning the network being built.
//   input_tensor - tensor to be rank-expanded; returned (possibly replaced by
//                  a shuffle-layer output) once ranks match.
nvinfer1::ITensor *WhereTensorRT::GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor) {
  auto input_cond_dims = input(ctx, 0).trt_tensor_->getDimensions();
  nvinfer1::Dims in_tensor_dims = input_tensor->getDimensions();
  while (in_tensor_dims.nbDims < input_cond_dims.nbDims) {
    input_tensor = ExpandDim(ctx, input_tensor, 0);
    if (input_tensor->getDimensions().nbDims == -1) {
      // Fix: previous message blamed ConvertCudaDims, but the failing call
      // here is ExpandDim. Control flow is preserved (log and continue).
      MS_LOG(ERROR) << "ExpandDim failed for " << op_name_;
    }
    // NOTE(review): this shuffle reshapes the tensor to its own dimensions and
    // looks redundant right after ExpandDim — confirm whether it is required
    // to materialize the new rank in the TensorRT graph.
    nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor);
    if (shuffle_layer == nullptr) {
      // Guard against a null layer before dereferencing; return what we have
      // so the caller fails later with a diagnosable graph instead of a crash.
      MS_LOG(ERROR) << "addShuffle failed for " << op_name_;
      return input_tensor;
    }
    shuffle_layer->setReshapeDimensions(input_tensor->getDimensions());
    input_tensor = shuffle_layer->getOutput(0);
    in_tensor_dims = input_tensor->getDimensions();
  }
  return input_tensor;
}
int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "network or input tensor is invalid";
@ -61,13 +77,21 @@ int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) {
// broadcast to same shape
if (input_x_dims.nbDims != input_y_dims.nbDims) {
if (input_x_dims.nbDims > input_y_dims.nbDims) {
auto expect_shape = ConvertMSShape(input(ctx, INPUT_X_INDEX).trt_tensor_->getDimensions());
inputTensors[INPUT_Y_INDEX] =
ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_Y_INDEX], expect_shape, op_name_ + "_broadcast_inputy");
auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_X_INDEX).trt_tensor_)->getOutput(0);
auto inputy = GetBroadcastTensor(ctx, input(ctx, INPUT_Y_INDEX).trt_tensor_);
auto size_tensor = ctx->network()->addShape(*inputy)->getOutput(0);
size_tensor = ctx->network()
->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX)
->getOutput(0);
inputTensors[INPUT_Y_INDEX] = Broadcast(ctx, inputy, size_tensor);
} else {
auto expect_shape = ConvertMSShape(input(ctx, INPUT_Y_INDEX).trt_tensor_->getDimensions());
inputTensors[INPUT_X_INDEX] =
ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_X_INDEX], expect_shape, op_name_ + "_broadcast_inputx");
auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_Y_INDEX).trt_tensor_)->getOutput(0);
auto inputx = GetBroadcastTensor(ctx, input(ctx, INPUT_X_INDEX).trt_tensor_);
auto size_tensor = ctx->network()->addShape(*inputx)->getOutput(0);
size_tensor = ctx->network()
->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX)
->getOutput(0);
inputTensors[INPUT_X_INDEX] = Broadcast(ctx, inputx, size_tensor);
}
}
@ -126,6 +150,11 @@ nvinfer1::IPluginV2DynamicExt *WherePlugin::clone() const noexcept {
return plugin;
}
// Reports the plugin's output element type: the selected-value output takes
// the data type of the tensor at input index 1 (the first value input), not
// the condition tensor at input 0.
nvinfer1::DataType WherePlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept {
return inputTypes[1];
}
size_t WherePlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); }
void WherePlugin::serialize(void *buffer) const noexcept {

View File

@ -36,6 +36,9 @@ class WhereTensorRT : public TensorRTOp {
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
nvinfer1::ITensor *GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor);
};
constexpr char *WHERE_PLUGIN_NAME{"WherePlugin"};
@ -62,6 +65,8 @@ class WherePlugin : public TensorRTPlugin {
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept override;
private:
int RunCudaWhere(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,