diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc index a5f6dee98dc..ccab6d71201 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc @@ -51,12 +51,12 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) { input(ctx, ON_VALUE_INDEX).trt_tensor_, input(ctx, OFF_VALUE_INDEX).trt_tensor_}; ITensorHelper indice_helper = input(ctx, 0); if (indice_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in"); + inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in"); } ITensorHelper depth_helper = input(ctx, DEPTH_INDEX); if (depth_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { inputTensors[DEPTH_INDEX] = - TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in"); + TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in"); } mindspore::MSTensor &depth_tensor = in_tensors_[DEPTH_INDEX]; if (depth_tensor.Data() == nullptr) { @@ -65,7 +65,9 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) { } const int *depth_ptr = reinterpret_cast(depth_tensor.Data().get()); int depth = *depth_ptr; - auto plugin = std::make_shared(op_name_, depth, op_primitive_->value_type()); + auto onehot_op = op_primitive_->value_as_OneHot(); + int axis = onehot_op->axis(); + auto plugin = std::make_shared(op_name_, axis, depth, op_primitive_->value_type()); if (plugin == nullptr) { MS_LOG(ERROR) << "create OnehotPlugin failed for " << op_name_; return RET_ERROR; @@ -98,12 +100,19 @@ int OnehotPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvi int OnehotPlugin::RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, cudaStream_t stream) { + int left_dims = 1; + int right_dims = 1; for (int i = 0; i != inputDesc[0].dims.nbDims; ++i) { - feature_dims_ *= inputDesc[0].dims.d[i]; + if (axis_ == -1 || i < IntToSize(axis_)) { + left_dims *= inputDesc[0].dims.d[i]; + } + if (axis_ != -1 && i >= IntToSize(axis_)) { + right_dims *= inputDesc[0].dims.d[i]; + } } if (inputDesc[0].type == nvinfer1::DataType::kINT32 && inputDesc[ON_VALUE_INDEX].type == nvinfer1::DataType::kFLOAT) { OneHot(static_cast(inputs[0]), depth_, static_cast(inputs[ON_VALUE_INDEX]), - static_cast(inputs[OFF_VALUE_INDEX]), batch_dims_, feature_dims_, + static_cast(inputs[OFF_VALUE_INDEX]), left_dims, right_dims, static_cast(outputs[0]), device_id_, stream); } else { MS_LOG(ERROR) << "invalid onehot type: " << static_cast(primitive_type_); @@ -117,16 +126,22 @@ nvinfer1::DimsExprs OnehotPlugin::getOutputDimensions(int32_t index, const nvinf nvinfer1::DimsExprs dims; dims.nbDims = inputs[0].nbDims + 1; auto indice_dims = inputs[0].nbDims; - if (indice_dims == 1) { - dims.d[0] = inputs[0].d[0]; - auto depth = exprBuilder.constant(depth_); - dims.d[1] = depth; - } else { + if (axis_ == -1) { for (int i = 0; i != inputs[0].nbDims; ++i) { dims.d[i] = inputs[0].d[i]; } auto depth = exprBuilder.constant(depth_); - dims.d[inputs[0].nbDims] = depth; + dims.d[dims.nbDims - 1] = depth; + } else { + for (int i = 0; i != indice_dims; ++i) { + if (i >= axis_) { + dims.d[i + 1] = inputs[0].d[i]; + } else { + dims.d[i] = inputs[0].d[i]; + } + } + auto depth = exprBuilder.constant(depth_); + dims.d[axis_] = depth; } return dims; } @@ -137,11 +152,15 @@ nvinfer1::IPluginV2DynamicExt *OnehotPlugin::clone() const noexcept { return plugin; } -size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 3 * sizeof(int); } +size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 2 * sizeof(int); } + +nvinfer1::DataType OnehotPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const + noexcept { + return inputTypes[ON_VALUE_INDEX]; +} void OnehotPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &batch_dims_, sizeof(int)); - SerializeValue(&buffer, &feature_dims_, sizeof(int)); + SerializeValue(&buffer, &axis_, sizeof(int)); SerializeValue(&buffer, &depth_, sizeof(int)); SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType)); } diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h index d93dd1ccc0f..539b20adea9 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h @@ -41,20 +41,23 @@ class OnehotTensorRT : public TensorRTOp { constexpr char *ONEHOT_PLUGIN_NAME{"OnehotPlugin"}; class OnehotPlugin : public TensorRTPlugin { public: - OnehotPlugin(const std::string name, int depth, schema::PrimitiveType primitive_type) - : TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)), depth_(depth), primitive_type_(primitive_type) {} + OnehotPlugin(const std::string name, int axis, int depth, schema::PrimitiveType primitive_type) + : TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)), + axis_(axis), + depth_(depth), + primitive_type_(primitive_type) {} OnehotPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) { const nvinfer1::PluginField *fields = fc->fields; - depth_ = static_cast(fields[0].data)[0]; - primitive_type_ = static_cast(fields[1].data)[0]; + axis_ = static_cast(fields[0].data)[0]; + depth_ = static_cast(fields[1].data)[0]; + primitive_type_ = static_cast(fields[2].data)[0]; } OnehotPlugin(const char *name, const void *serialData, size_t serialLength) : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &batch_dims_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &feature_dims_, sizeof(int)); + DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int)); DeserializeValue(&serialData, &serialLength, &depth_, sizeof(int)); DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType)); } @@ -65,6 +68,8 @@ class OnehotPlugin : public TensorRTPlugin { int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; size_t getSerializationSize() const noexcept override; + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const + noexcept override; void serialize(void *buffer) const noexcept override; nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept override; @@ -72,10 +77,8 @@ class OnehotPlugin : public TensorRTPlugin { private: int RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, cudaStream_t stream); - - int batch_dims_{1}; - int feature_dims_{1}; int depth_{1}; + int axis_{-1}; schema::PrimitiveType primitive_type_; }; class OnehotPluginCreater : public TensorRTPluginCreater { diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc index 45ef5fe4e22..d60a39f19f8 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc @@ -40,6 +40,22 @@ int WhereTensorRT::IsSupport(const schema::Primitive *primitive, const std::vect return RET_OK; } +nvinfer1::ITensor *WhereTensorRT::GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor) { + auto input_cond_dims = input(ctx, 0).trt_tensor_->getDimensions(); + nvinfer1::Dims in_tensor_dims = input_tensor->getDimensions(); + while (in_tensor_dims.nbDims < input_cond_dims.nbDims) { + input_tensor = ExpandDim(ctx, input_tensor, 0); + if (input_tensor->getDimensions().nbDims == -1) { + MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; + } + nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor); + shuffle_layer->setReshapeDimensions(input_tensor->getDimensions()); + input_tensor = shuffle_layer->getOutput(0); + in_tensor_dims = input_tensor->getDimensions(); + } + return input_tensor; +} + int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) { if (ctx == nullptr || ctx->network() == nullptr) { MS_LOG(ERROR) << "network or input tensor is invalid"; @@ -61,13 +77,21 @@ int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) { // broadcast to same shape if (input_x_dims.nbDims != input_y_dims.nbDims) { if (input_x_dims.nbDims > input_y_dims.nbDims) { - auto expect_shape = ConvertMSShape(input(ctx, INPUT_X_INDEX).trt_tensor_->getDimensions()); - inputTensors[INPUT_Y_INDEX] = - ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_Y_INDEX], expect_shape, op_name_ + "_broadcast_inputy"); + auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_X_INDEX).trt_tensor_)->getOutput(0); + auto inputy = GetBroadcastTensor(ctx, input(ctx, INPUT_Y_INDEX).trt_tensor_); + auto size_tensor = ctx->network()->addShape(*inputy)->getOutput(0); + size_tensor = ctx->network() + ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX) + ->getOutput(0); + inputTensors[INPUT_Y_INDEX] = Broadcast(ctx, inputy, size_tensor); } else { - auto expect_shape = ConvertMSShape(input(ctx, INPUT_Y_INDEX).trt_tensor_->getDimensions()); - inputTensors[INPUT_X_INDEX] = - ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_X_INDEX], expect_shape, op_name_ + "_broadcast_inputx"); + auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_Y_INDEX).trt_tensor_)->getOutput(0); + auto inputx = GetBroadcastTensor(ctx, input(ctx, INPUT_X_INDEX).trt_tensor_); + auto size_tensor = ctx->network()->addShape(*inputx)->getOutput(0); + size_tensor = ctx->network() + ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX) + ->getOutput(0); + inputTensors[INPUT_X_INDEX] = Broadcast(ctx, inputx, size_tensor); } } @@ -126,6 +150,11 @@ nvinfer1::IPluginV2DynamicExt *WherePlugin::clone() const noexcept { return plugin; } +nvinfer1::DataType WherePlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const + noexcept { + return inputTypes[1]; +} + size_t WherePlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } void WherePlugin::serialize(void *buffer) const noexcept { diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h index b84f159c3cf..a4158d1bd7f 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h @@ -36,6 +36,9 @@ class WhereTensorRT : public TensorRTOp { int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, const std::vector &out_tensors) override; + + private: + nvinfer1::ITensor *GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor); }; constexpr char *WHERE_PLUGIN_NAME{"WherePlugin"}; @@ -62,6 +65,8 @@ class WherePlugin : public TensorRTPlugin { const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; size_t getSerializationSize() const noexcept override; void serialize(void *buffer) const noexcept override; + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const + noexcept override; private: int RunCudaWhere(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,