forked from mindspore-Ecosystem/mindspore
!42863 [LITE] fix onehot output datatype and add axis attribute
Merge pull request !42863 from WangWenzhe/1.8_trt_ops
commit a8cfa7e3ad
@@ -51,12 +51,12 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) {
                                  input(ctx, ON_VALUE_INDEX).trt_tensor_, input(ctx, OFF_VALUE_INDEX).trt_tensor_};
   ITensorHelper indice_helper = input(ctx, 0);
   if (indice_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) {
-    inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in");
+    inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in");
   }
   ITensorHelper depth_helper = input(ctx, DEPTH_INDEX);
   if (depth_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) {
     inputTensors[DEPTH_INDEX] =
-      TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in");
+      TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in");
   }
   mindspore::MSTensor &depth_tensor = in_tensors_[DEPTH_INDEX];
   if (depth_tensor.Data() == nullptr) {
@@ -65,7 +65,9 @@ int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) {
   }
   const int *depth_ptr = reinterpret_cast<const int *>(depth_tensor.Data().get());
   int depth = *depth_ptr;
-  auto plugin = std::make_shared<OnehotPlugin>(op_name_, depth, op_primitive_->value_type());
+  auto onehot_op = op_primitive_->value_as_OneHot();
+  int axis = onehot_op->axis();
+  auto plugin = std::make_shared<OnehotPlugin>(op_name_, axis, depth, op_primitive_->value_type());
   if (plugin == nullptr) {
     MS_LOG(ERROR) << "create OnehotPlugin failed for " << op_name_;
     return RET_ERROR;
@@ -98,12 +100,19 @@ int OnehotPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvi
 
 int OnehotPlugin::RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs,
                                 void *const *outputs, cudaStream_t stream) {
+  int left_dims = 1;
+  int right_dims = 1;
   for (int i = 0; i != inputDesc[0].dims.nbDims; ++i) {
-    feature_dims_ *= inputDesc[0].dims.d[i];
+    if (axis_ == -1 || i < IntToSize(axis_)) {
+      left_dims *= inputDesc[0].dims.d[i];
+    }
+    if (axis_ != -1 && i >= IntToSize(axis_)) {
+      right_dims *= inputDesc[0].dims.d[i];
+    }
   }
   if (inputDesc[0].type == nvinfer1::DataType::kINT32 && inputDesc[ON_VALUE_INDEX].type == nvinfer1::DataType::kFLOAT) {
     OneHot<float, int>(static_cast<const int *>(inputs[0]), depth_, static_cast<const float *>(inputs[ON_VALUE_INDEX]),
-                       static_cast<const float *>(inputs[OFF_VALUE_INDEX]), batch_dims_, feature_dims_,
+                       static_cast<const float *>(inputs[OFF_VALUE_INDEX]), left_dims, right_dims,
                        static_cast<float *>(outputs[0]), device_id_, stream);
   } else {
     MS_LOG(ERROR) << "invalid onehot type: " << static_cast<int>(primitive_type_);
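For reference, a minimal host-side sketch (not part of this patch) of what the left_dims/right_dims split implies for the kernel call above: the output is laid out as a [left_dims, depth, right_dims] block, the index tensor as [left_dims, right_dims], and each (left, right) pair sets the slot selected by its index to on_value while every other slot keeps off_value. The function name and the scalar on/off parameters are illustrative only.

#include <vector>

// Host reference for the [left_dims, depth, right_dims] one-hot layout (illustrative only).
std::vector<float> OneHotReference(const std::vector<int> &indices, int depth, int left_dims, int right_dims,
                                   float on_value, float off_value) {
  std::vector<float> out(static_cast<size_t>(left_dims) * depth * right_dims, off_value);
  for (int left = 0; left < left_dims; ++left) {
    for (int right = 0; right < right_dims; ++right) {
      int idx = indices[left * right_dims + right];  // index tensor flattened as [left_dims, right_dims]
      if (idx >= 0 && idx < depth) {
        out[(static_cast<size_t>(left) * depth + idx) * right_dims + right] = on_value;
      }
    }
  }
  return out;
}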
@@ -117,16 +126,22 @@ nvinfer1::DimsExprs OnehotPlugin::getOutputDimensions(int32_t index, const nvinf
   nvinfer1::DimsExprs dims;
   dims.nbDims = inputs[0].nbDims + 1;
   auto indice_dims = inputs[0].nbDims;
   if (indice_dims == 1) {
     dims.d[0] = inputs[0].d[0];
     auto depth = exprBuilder.constant(depth_);
     dims.d[1] = depth;
   } else {
+    if (axis_ == -1) {
     for (int i = 0; i != inputs[0].nbDims; ++i) {
       dims.d[i] = inputs[0].d[i];
     }
     auto depth = exprBuilder.constant(depth_);
-    dims.d[inputs[0].nbDims] = depth;
+    dims.d[dims.nbDims - 1] = depth;
+    } else {
+      for (int i = 0; i != indice_dims; ++i) {
+        if (i >= axis_) {
+          dims.d[i + 1] = inputs[0].d[i];
+        } else {
+          dims.d[i] = inputs[0].d[i];
+        }
+      }
+      auto depth = exprBuilder.constant(depth_);
+      dims.d[axis_] = depth;
+    }
   }
   return dims;
 }
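As a standalone illustration of the dimension rule above (plain integer shapes rather than TensorRT dimension expressions): the output shape is the indices shape with a depth-sized dimension inserted at axis, or appended when axis is -1.

#include <vector>

// Insert `depth` into the indices shape at `axis`; append it when axis == -1.
std::vector<int> OneHotOutputShape(const std::vector<int> &indices_shape, int depth, int axis) {
  std::vector<int> out = indices_shape;
  if (axis == -1) {
    out.push_back(depth);
  } else {
    out.insert(out.begin() + axis, depth);
  }
  return out;
}

// e.g. indices shape {2, 3} with depth 4: axis -1 -> {2, 3, 4}, axis 0 -> {4, 2, 3}, axis 1 -> {2, 4, 3}.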
@@ -137,11 +152,15 @@ nvinfer1::IPluginV2DynamicExt *OnehotPlugin::clone() const noexcept {
   return plugin;
 }
 
-size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 3 * sizeof(int); }
+size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 2 * sizeof(int); }
 
+nvinfer1::DataType OnehotPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
+  noexcept {
+  return inputTypes[ON_VALUE_INDEX];
+}
+
 void OnehotPlugin::serialize(void *buffer) const noexcept {
-  SerializeValue(&buffer, &batch_dims_, sizeof(int));
-  SerializeValue(&buffer, &feature_dims_, sizeof(int));
+  SerializeValue(&buffer, &axis_, sizeof(int));
   SerializeValue(&buffer, &depth_, sizeof(int));
   SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType));
 }
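The serialized layout now has to stay in sync in three places: getSerializationSize() (two ints plus the primitive type), serialize() (axis, then depth, then primitive type), and the deserializing constructor in the header, which reads the fields back in the same order. A hedged sketch of that round trip, with plain memcpy stand-ins for the SerializeValue/DeserializeValue helpers:

#include <cassert>
#include <cstring>

// Stand-ins for the SerializeValue/DeserializeValue helpers (illustrative only).
static void WriteValue(char **buf, const void *value, size_t size) {
  std::memcpy(*buf, value, size);
  *buf += size;
}
static void ReadValue(const char **buf, void *value, size_t size) {
  std::memcpy(value, *buf, size);
  *buf += size;
}

int main() {
  int axis = 1, depth = 4, prim = 77;  // prim stands in for schema::PrimitiveType
  char storage[3 * sizeof(int)];
  char *w = storage;
  WriteValue(&w, &axis, sizeof(int));   // order: axis
  WriteValue(&w, &depth, sizeof(int));  // then depth
  WriteValue(&w, &prim, sizeof(int));   // then the primitive type

  int axis2 = 0, depth2 = 0, prim2 = 0;
  const char *r = storage;
  ReadValue(&r, &axis2, sizeof(int));   // must be read back in exactly the same order
  ReadValue(&r, &depth2, sizeof(int));
  ReadValue(&r, &prim2, sizeof(int));
  assert(axis2 == axis && depth2 == depth && prim2 == prim);
  return 0;
}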
@@ -41,20 +41,23 @@ class OnehotTensorRT : public TensorRTOp {
 constexpr char *ONEHOT_PLUGIN_NAME{"OnehotPlugin"};
 class OnehotPlugin : public TensorRTPlugin {
  public:
-  OnehotPlugin(const std::string name, int depth, schema::PrimitiveType primitive_type)
-      : TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)), depth_(depth), primitive_type_(primitive_type) {}
+  OnehotPlugin(const std::string name, int axis, int depth, schema::PrimitiveType primitive_type)
+      : TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)),
+        axis_(axis),
+        depth_(depth),
+        primitive_type_(primitive_type) {}
 
   OnehotPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
       : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) {
     const nvinfer1::PluginField *fields = fc->fields;
-    depth_ = static_cast<const int *>(fields[0].data)[0];
-    primitive_type_ = static_cast<const schema::PrimitiveType *>(fields[1].data)[0];
+    axis_ = static_cast<const int *>(fields[0].data)[0];
+    depth_ = static_cast<const int *>(fields[1].data)[0];
+    primitive_type_ = static_cast<const schema::PrimitiveType *>(fields[2].data)[0];
   }
 
   OnehotPlugin(const char *name, const void *serialData, size_t serialLength)
       : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) {
-    DeserializeValue(&serialData, &serialLength, &batch_dims_, sizeof(int));
-    DeserializeValue(&serialData, &serialLength, &feature_dims_, sizeof(int));
+    DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int));
     DeserializeValue(&serialData, &serialLength, &depth_, sizeof(int));
     DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType));
   }
@@ -65,6 +68,8 @@ class OnehotPlugin : public TensorRTPlugin {
   int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
               const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
   size_t getSerializationSize() const noexcept override;
+  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
+    noexcept override;
   void serialize(void *buffer) const noexcept override;
   nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
                                           nvinfer1::IExprBuilder &exprBuilder) noexcept override;
@@ -72,10 +77,8 @@ class OnehotPlugin : public TensorRTPlugin {
  private:
   int RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,
                     cudaStream_t stream);
 
-  int batch_dims_{1};
-  int feature_dims_{1};
   int depth_{1};
+  int axis_{-1};
   schema::PrimitiveType primitive_type_;
 };
 class OnehotPluginCreater : public TensorRTPluginCreater<OnehotPlugin> {
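The field-based constructor above hard-codes the order fields[0] = axis, fields[1] = depth, fields[2] = primitive type, so whichever creator builds the PluginFieldCollection has to pack it the same way. A hedged sketch of such a collection; the field names and the use of kINT32 for the primitive type are assumptions here, not taken from this patch:

#include <vector>
#include <NvInfer.h>

// Illustrative only: build a PluginFieldCollection whose layout matches
// fields[0] = axis, fields[1] = depth, fields[2] = primitive type.
nvinfer1::PluginFieldCollection MakeOnehotFields(const int *axis, const int *depth, const int *primitive_type,
                                                 std::vector<nvinfer1::PluginField> *storage) {
  storage->clear();
  storage->emplace_back("axis", axis, nvinfer1::PluginFieldType::kINT32, 1);
  storage->emplace_back("depth", depth, nvinfer1::PluginFieldType::kINT32, 1);
  storage->emplace_back("primitive_type", primitive_type, nvinfer1::PluginFieldType::kINT32, 1);
  nvinfer1::PluginFieldCollection fc;
  fc.nbFields = static_cast<int32_t>(storage->size());
  fc.fields = storage->data();
  return fc;
}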
@@ -40,6 +40,22 @@ int WhereTensorRT::IsSupport(const schema::Primitive *primitive, const std::vect
   return RET_OK;
 }
 
+nvinfer1::ITensor *WhereTensorRT::GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor) {
+  auto input_cond_dims = input(ctx, 0).trt_tensor_->getDimensions();
+  nvinfer1::Dims in_tensor_dims = input_tensor->getDimensions();
+  while (in_tensor_dims.nbDims < input_cond_dims.nbDims) {
+    input_tensor = ExpandDim(ctx, input_tensor, 0);
+    if (input_tensor->getDimensions().nbDims == -1) {
+      MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
+    }
+    nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor);
+    shuffle_layer->setReshapeDimensions(input_tensor->getDimensions());
+    input_tensor = shuffle_layer->getOutput(0);
+    in_tensor_dims = input_tensor->getDimensions();
+  }
+  return input_tensor;
+}
+
 int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) {
   if (ctx == nullptr || ctx->network() == nullptr) {
     MS_LOG(ERROR) << "network or input tensor is invalid";
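In plain shape terms, the loop in GetBroadcastTensor keeps prepending a unit dimension until the tensor's rank matches the condition input; a tiny sketch of that rank-alignment rule (this mirrors only the shape handling, not the ExpandDim/IShuffleLayer plumbing):

#include <vector>

// Left-pad a shape with 1s until it reaches target_rank, e.g. {5, 3} aligned to rank 4 -> {1, 1, 5, 3}.
std::vector<int> AlignRank(std::vector<int> shape, size_t target_rank) {
  while (shape.size() < target_rank) {
    shape.insert(shape.begin(), 1);
  }
  return shape;
}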
@@ -61,13 +77,21 @@ int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) {
   // broadcast to same shape
   if (input_x_dims.nbDims != input_y_dims.nbDims) {
     if (input_x_dims.nbDims > input_y_dims.nbDims) {
-      auto expect_shape = ConvertMSShape(input(ctx, INPUT_X_INDEX).trt_tensor_->getDimensions());
-      inputTensors[INPUT_Y_INDEX] =
-        ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_Y_INDEX], expect_shape, op_name_ + "_broadcast_inputy");
+      auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_X_INDEX).trt_tensor_)->getOutput(0);
+      auto inputy = GetBroadcastTensor(ctx, input(ctx, INPUT_Y_INDEX).trt_tensor_);
+      auto size_tensor = ctx->network()->addShape(*inputy)->getOutput(0);
+      size_tensor = ctx->network()
+                      ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX)
+                      ->getOutput(0);
+      inputTensors[INPUT_Y_INDEX] = Broadcast(ctx, inputy, size_tensor);
     } else {
-      auto expect_shape = ConvertMSShape(input(ctx, INPUT_Y_INDEX).trt_tensor_->getDimensions());
-      inputTensors[INPUT_X_INDEX] =
-        ConvertConstantTensorWithDims(ctx, in_tensors_[INPUT_X_INDEX], expect_shape, op_name_ + "_broadcast_inputx");
+      auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_Y_INDEX).trt_tensor_)->getOutput(0);
+      auto inputx = GetBroadcastTensor(ctx, input(ctx, INPUT_X_INDEX).trt_tensor_);
+      auto size_tensor = ctx->network()->addShape(*inputx)->getOutput(0);
+      size_tensor = ctx->network()
+                      ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX)
+                      ->getOutput(0);
+      inputTensors[INPUT_X_INDEX] = Broadcast(ctx, inputx, size_tensor);
     }
   }
 
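The addElementWise(..., kMAX) over the two shape tensors picks the larger extent per dimension as the broadcast target that Broadcast() then expands to. The same rule on plain, already rank-aligned shapes (illustrative only):

#include <algorithm>
#include <vector>

// Per-dimension max of two rank-aligned shapes, e.g. {1, 1, 5, 3} and {2, 4, 1, 3} -> {2, 4, 5, 3}.
std::vector<int> BroadcastTargetShape(const std::vector<int> &a, const std::vector<int> &b) {
  std::vector<int> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}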
@@ -126,6 +150,11 @@ nvinfer1::IPluginV2DynamicExt *WherePlugin::clone() const noexcept {
   return plugin;
 }
 
+nvinfer1::DataType WherePlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
+  noexcept {
+  return inputTypes[1];
+}
+
 size_t WherePlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); }
 
 void WherePlugin::serialize(void *buffer) const noexcept {
@@ -36,6 +36,9 @@ class WhereTensorRT : public TensorRTOp {
 
   int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                 const std::vector<mindspore::MSTensor> &out_tensors) override;
+
+ private:
+  nvinfer1::ITensor *GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor);
 };
 
 constexpr char *WHERE_PLUGIN_NAME{"WherePlugin"};
@@ -62,6 +65,8 @@ class WherePlugin : public TensorRTPlugin {
               const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
   size_t getSerializationSize() const noexcept override;
   void serialize(void *buffer) const noexcept override;
+  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
+    noexcept override;
 
  private:
   int RunCudaWhere(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,