forked from mindspore-Ecosystem/mindspore
!15259 trt operator
From: @wilfchen Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
0fa86fb295
|
@ -38,7 +38,7 @@ class TrtUtils {
|
|||
static std::map<nvinfer1::DataType, TypeId> type_list = {{nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
|
||||
{nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
|
||||
{nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
|
||||
{nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt}};
|
||||
{nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt32}};
|
||||
|
||||
auto iter = type_list.find(trt_dtype);
|
||||
if (iter == type_list.end()) {
|
||||
|
@ -51,7 +51,8 @@ class TrtUtils {
|
|||
static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
|
||||
{TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
|
||||
{TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
|
||||
{TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32}};
|
||||
{TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32},
|
||||
{TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
|
||||
auto iter = type_list.find(ms_dtype);
|
||||
if (iter == type_list.end()) {
|
||||
MS_LOG(EXCEPTION) << "data type not support: " << ms_dtype;
|
||||
|
@ -69,7 +70,7 @@ class TrtUtils {
|
|||
return trt_dims;
|
||||
}
|
||||
|
||||
static nvinfer1::Dims TrtDimsToMsDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
|
||||
static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
|
||||
nvinfer1::Dims trt_dims;
|
||||
size_t offset = ignore_batch_dim ? 1 : 0;
|
||||
for (size_t i = offset; i < ms_shape.size(); ++i) {
|
||||
|
|
|
@ -24,6 +24,19 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
nvinfer1::ITensor *ToTensor(LayerInput *input, const std::vector<size_t> &shape,
|
||||
std::shared_ptr<TrtConverterContext> context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input->IsTensor()) {
|
||||
return input->tensor();
|
||||
}
|
||||
|
||||
const nvinfer1::Dims &dim = TrtUtils::MsDimsToTrtDims(shape, false);
|
||||
auto *const_layer = context->network()->addConstant(dim, *input->weight());
|
||||
MS_EXCEPTION_IF_NULL(const_layer);
|
||||
return const_layer->getOutput(0);
|
||||
}
|
||||
|
||||
ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
|
@ -34,15 +47,7 @@ ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
|
||||
auto *layer = context->network()->addShuffle(*inputs[0].tensor());
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
const auto &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
const auto &output_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||
if (input_shape[0] != output_shape[0]) {
|
||||
MS_LOG(ERROR) << "Reshape does not support modify batch size. Input batch size: " << input_shape[0]
|
||||
<< "Output batch size: " << output_shape[0];
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const nvinfer1::Dims &dims = TrtUtils::MsDimsToTrtDims(output_shape, false);
|
||||
layer->setReshapeDimensions(dims);
|
||||
|
||||
|
@ -62,7 +67,6 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
const std::vector<size_t> &x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
|
||||
const std::vector<size_t> &y_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||
|
||||
// Keep to output
|
||||
auto Broadcast = [&context, &y_shape](nvinfer1::ITensor *tensor, const std::vector<size_t> &x_shape) {
|
||||
if (x_shape.size() == y_shape.size()) {
|
||||
return tensor;
|
||||
|
@ -88,8 +92,8 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
return layer->getOutput(0);
|
||||
};
|
||||
|
||||
auto *x1 = Broadcast(inputs[0].tensor(), x1_shape);
|
||||
auto *x2 = Broadcast(inputs[1].tensor(), x2_shape);
|
||||
auto *x1 = Broadcast(ToTensor(&inputs[0], x1_shape, context), x1_shape);
|
||||
auto *x2 = Broadcast(ToTensor(&inputs[1], x2_shape, context), x2_shape);
|
||||
auto *layer = context->network()->addElementWise(*x1, *x2, op_type);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
|
@ -142,6 +146,80 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterCo
|
|||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
nvinfer1::UnaryOperation op_type) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
nvinfer1::ReduceOperation op_type) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
// Calculate reduce axes bitmask
|
||||
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
|
||||
uint32_t reduce_axes = 0;
|
||||
if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
|
||||
const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
|
||||
for (size_t i = 0; i < axis.size(); i++) {
|
||||
int offset = axis[i] >= 0 ? LongToInt(axis[i]) : LongToInt(axis[i] + input_shape.size());
|
||||
reduce_axes |= 1UL << offset;
|
||||
}
|
||||
} else {
|
||||
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
|
||||
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
|
||||
reduce_axes = 1UL << offset;
|
||||
}
|
||||
|
||||
// Tensor-RT do not support reduce with no dimensions.
|
||||
// Skip reduce operator if reduce_axes == 0
|
||||
if (reduce_axes == 0) {
|
||||
MS_LOG(WARNING) << "No dimension be be reduced. " << node->DebugString();
|
||||
return {true, {LayerInput(inputs[0].tensor())}};
|
||||
}
|
||||
|
||||
bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
|
||||
// Tensor-RT do not support reduce all dimensions with keep_dims == false.
|
||||
// Reduce with keep_dims = true, add apply reshape latter.
|
||||
bool post_reshape = false;
|
||||
if (keep_dims == false && (reduce_axes == (1UL << input_shape.size()) - 1)) {
|
||||
keep_dims = true;
|
||||
post_reshape = true;
|
||||
}
|
||||
|
||||
nvinfer1::IReduceLayer *layer = context->network()->addReduce(*inputs[0].tensor(), op_type, reduce_axes, keep_dims);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
if (post_reshape) {
|
||||
nvinfer1::IShuffleLayer *reshape_layer = context->network()->addShuffle(*layer->getOutput(0));
|
||||
MS_EXCEPTION_IF_NULL(reshape_layer);
|
||||
|
||||
nvinfer1::Dims dim;
|
||||
dim.nbDims = 1;
|
||||
dim.d[1] = 1;
|
||||
reshape_layer->setReshapeDimensions(dim);
|
||||
|
||||
return {true, {LayerInput(reshape_layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
|
||||
|
@ -195,6 +273,7 @@ MS_TRT_CONVERTER_FUNC_REG(Add) { return AddElementLayer(node, context, nvinfer1:
|
|||
MS_TRT_CONVERTER_FUNC_REG(Sub) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kSUB); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Mul) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPROD); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Div) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(RealDiv) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Pow) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPOW); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Maximum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMAX); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Minimum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMIN); }
|
||||
|
@ -202,6 +281,33 @@ MS_TRT_CONVERTER_FUNC_REG(FloorDiv) {
|
|||
return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kFLOOR_DIV);
|
||||
}
|
||||
|
||||
// Unary operators
|
||||
MS_TRT_CONVERTER_FUNC_REG(Exp) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kEXP); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Log) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kLOG); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Sqrt) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSQRT); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Reciprocal) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kRECIP); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Abs) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kABS); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Neg) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kNEG); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Sin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSIN); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(COS) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOS); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Tan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kTAN); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Sinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSINH); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Cosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOSH); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Asin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASIN); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Acos) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOS); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Atan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kATAN); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Asinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASINH); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Acosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOSH); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Ceil) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCEIL); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Floor) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kFLOOR); }
|
||||
|
||||
// Reduce operators
|
||||
MS_TRT_CONVERTER_FUNC_REG(ReduceSum) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kSUM); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(ReduceMean) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kAVG); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(ReduceMax) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kMAX); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(ReduceMin) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kMIN); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(ReduceProd) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kPROD); }
|
||||
|
||||
// Pooling operators.
|
||||
MS_TRT_CONVERTER_FUNC_REG(AvgPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kAVERAGE); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(MaxPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kMAX); }
|
||||
|
@ -304,6 +410,45 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
|
|||
}
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 2) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const auto &transpose_a = AnfAlgo::GetNodeAttr<bool>(node, "transpose_a");
|
||||
const auto &transpose_b = AnfAlgo::GetNodeAttr<bool>(node, "transpose_b");
|
||||
const auto &trt_transpose1 = transpose_a ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
|
||||
const auto &trt_transpose2 = transpose_b ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
|
||||
|
||||
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
|
||||
|
||||
auto SwapLastDims = [](std::vector<size_t> shape, const bool &transpose) {
|
||||
if (shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Operation not support: input rank should >= 2";
|
||||
}
|
||||
|
||||
if (!transpose) {
|
||||
return shape;
|
||||
}
|
||||
|
||||
size_t tmp = shape[shape.size() - 2];
|
||||
shape[shape.size() - 2] = shape[shape.size() - 1];
|
||||
shape[shape.size() - 1] = tmp;
|
||||
return shape;
|
||||
};
|
||||
|
||||
nvinfer1::ITensor *tensor1 = ToTensor(&inputs[0], SwapLastDims(shape1, transpose_a), context);
|
||||
nvinfer1::ITensor *tensor2 = ToTensor(&inputs[1], SwapLastDims(shape2, transpose_b), context);
|
||||
auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
|
@ -321,7 +466,7 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
|
|||
return {false, {}};
|
||||
}
|
||||
|
||||
// Convert Weight to ITensor which
|
||||
// Convert Weight to ITensor
|
||||
nvinfer1::Dims unsqueeze_bias_dims;
|
||||
unsqueeze_bias_dims.nbDims = x_shape.size();
|
||||
std::fill(unsqueeze_bias_dims.d, unsqueeze_bias_dims.d + unsqueeze_bias_dims.nbDims, 1);
|
||||
|
@ -335,10 +480,9 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
|
|||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
// NoOp
|
||||
MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
|
||||
|
@ -466,5 +610,207 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
|
|||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Slice) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const auto &begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "begin");
|
||||
const auto &size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "size");
|
||||
|
||||
nvinfer1::Dims trt_start = TrtUtils::MsDimsToTrtDims(begin, false);
|
||||
nvinfer1::Dims trt_size = TrtUtils::MsDimsToTrtDims(size, false);
|
||||
nvinfer1::Dims trt_stride;
|
||||
for (int32_t i = 0; i < trt_start.nbDims; i++) {
|
||||
trt_stride.d[trt_stride.nbDims++] = 1;
|
||||
}
|
||||
|
||||
auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Transpose) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const auto &perm = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
|
||||
nvinfer1::Permutation trt_perm;
|
||||
for (size_t i = 0; i < perm.size(); i++) {
|
||||
trt_perm.order[i] = LongToInt(perm[i]);
|
||||
}
|
||||
|
||||
auto *layer = context->network()->addShuffle(*inputs[0].tensor());
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
layer->setFirstTranspose(trt_perm);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Softmax) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
|
||||
uint32_t reduce_axes = 0;
|
||||
if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
|
||||
const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
|
||||
if (axis.size() != 1) {
|
||||
MS_LOG(ERROR) << "Only one axis can be set. Axis size" << axis.size();
|
||||
return {false, {}};
|
||||
}
|
||||
int offset = axis[0] >= 0 ? LongToInt(axis[0]) : LongToInt(axis[0] + input_shape.size());
|
||||
reduce_axes = 1U << offset;
|
||||
} else {
|
||||
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
|
||||
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
|
||||
reduce_axes = 1UL << offset;
|
||||
}
|
||||
|
||||
auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
layer->setAxes(reduce_axes);
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
|
||||
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
|
||||
uint32_t reduce_axes = 1UL << offset;
|
||||
|
||||
auto *softmax_layer = context->network()->addSoftMax(*inputs[0].tensor());
|
||||
MS_EXCEPTION_IF_NULL(softmax_layer);
|
||||
softmax_layer->setAxes(reduce_axes);
|
||||
|
||||
auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
|
||||
MS_EXCEPTION_IF_NULL(log_layer);
|
||||
|
||||
return {true, {LayerInput(log_layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Gather) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 2) {
|
||||
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
auto axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
|
||||
axis = axis >= 0 ? axis : axis + input_shape.size();
|
||||
|
||||
nvinfer1::ITensor *input = ToTensor(&inputs[0], input_shape, context);
|
||||
const std::vector<size_t> &indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
|
||||
nvinfer1::ITensor *indices = ToTensor(&inputs[1], indices_shape, context);
|
||||
|
||||
auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Cast) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
const TypeId &dst_type = AnfAlgo::GetOutputInferDataType(node, 0);
|
||||
auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
|
||||
auto *layer = context->network()->addIdentity(*inputs[0].tensor());
|
||||
layer->setOutputType(0, trt_type);
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 3 || !inputs[0].IsTensor() || !inputs[1].IsWeight() || !inputs[2].IsWeight()) {
|
||||
MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
// Calculate reduce axes
|
||||
const std::vector<size_t> &input_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||
auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_norm_axis");
|
||||
begin_norm_axis = begin_norm_axis >= 0 ? begin_norm_axis : begin_norm_axis + input_shape.size();
|
||||
uint32_t reduce_axes = 0;
|
||||
for (size_t i = LongToSize(begin_norm_axis); i < input_shape.size(); i++) {
|
||||
reduce_axes |= 1UL << i;
|
||||
}
|
||||
|
||||
// Reshape gamma and beta for broadcast
|
||||
auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_params_axis");
|
||||
begin_params_axis = begin_params_axis >= 0 ? begin_params_axis : begin_params_axis + input_shape.size();
|
||||
std::vector<size_t> param_shape = input_shape;
|
||||
for (size_t j = 0; j < LongToSize(begin_params_axis); j++) {
|
||||
param_shape[j] = 1;
|
||||
}
|
||||
|
||||
auto epsilon = AnfAlgo::GetNodeAttr<float>(node, "epsilon");
|
||||
std::shared_ptr<tensor::Tensor> weight = context->CreateTempWeight(kNumberTypeFloat32, {1});
|
||||
auto value = static_cast<float *>(weight->data_c());
|
||||
value[0] = epsilon;
|
||||
nvinfer1::Dims dim;
|
||||
dim.nbDims = SizeToInt(input_shape.size());
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
dim.d[i] = 1;
|
||||
}
|
||||
auto *epsilon_layer = context->network()->addConstant(dim, nvinfer1::Weights{nvinfer1::DataType::kFLOAT, value, 1});
|
||||
MS_EXCEPTION_IF_NULL(epsilon_layer);
|
||||
|
||||
// y = (x - mean) / sqrt(var) * gamma + beta
|
||||
auto *mean = context->network()->addReduce(*inputs[0].tensor(), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
|
||||
MS_EXCEPTION_IF_NULL(mean);
|
||||
auto *sub =
|
||||
context->network()->addElementWise(*inputs[0].tensor(), *mean->getOutput(0), nvinfer1::ElementWiseOperation::kSUB);
|
||||
MS_EXCEPTION_IF_NULL(sub);
|
||||
auto *pow =
|
||||
context->network()->addElementWise(*sub->getOutput(0), *sub->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
|
||||
MS_EXCEPTION_IF_NULL(pow);
|
||||
auto *var = context->network()->addReduce(*pow->getOutput(0), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
|
||||
MS_EXCEPTION_IF_NULL(var);
|
||||
auto *var_epsilon = context->network()->addElementWise(*var->getOutput(0), *epsilon_layer->getOutput(0),
|
||||
nvinfer1::ElementWiseOperation::kSUM);
|
||||
MS_EXCEPTION_IF_NULL(var_epsilon);
|
||||
auto *std = context->network()->addUnary(*var_epsilon->getOutput(0), nvinfer1::UnaryOperation::kSQRT);
|
||||
MS_EXCEPTION_IF_NULL(std);
|
||||
auto *div =
|
||||
context->network()->addElementWise(*sub->getOutput(0), *std->getOutput(0), nvinfer1::ElementWiseOperation::kDIV);
|
||||
MS_EXCEPTION_IF_NULL(div);
|
||||
auto *mul = context->network()->addElementWise(*div->getOutput(0), *ToTensor(&inputs[1], param_shape, context),
|
||||
nvinfer1::ElementWiseOperation::kPROD);
|
||||
MS_EXCEPTION_IF_NULL(mul);
|
||||
auto *add = context->network()->addElementWise(*mul->getOutput(0), *ToTensor(&inputs[2], param_shape, context),
|
||||
nvinfer1::ElementWiseOperation::kSUM);
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
|
||||
return {true, {LayerInput(add->getOutput(0))}};
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue