!15259 trt operator

From: @wilfchen
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-04-23 14:10:55 +08:00 committed by Gitee
commit 0fa86fb295
2 changed files with 364 additions and 17 deletions

View File

@ -38,7 +38,7 @@ class TrtUtils {
static std::map<nvinfer1::DataType, TypeId> type_list = {{nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
{nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
{nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
{nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt}};
{nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt32}};
auto iter = type_list.find(trt_dtype);
if (iter == type_list.end()) {
@ -51,7 +51,8 @@ class TrtUtils {
static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
{TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
{TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
{TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32}};
{TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32},
{TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
auto iter = type_list.find(ms_dtype);
if (iter == type_list.end()) {
MS_LOG(EXCEPTION) << "data type not support: " << ms_dtype;
@ -69,7 +70,7 @@ class TrtUtils {
return trt_dims;
}
static nvinfer1::Dims TrtDimsToMsDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
nvinfer1::Dims trt_dims;
size_t offset = ignore_batch_dim ? 1 : 0;
for (size_t i = offset; i < ms_shape.size(); ++i) {

View File

@ -24,6 +24,19 @@
namespace mindspore {
namespace opt {
namespace {
nvinfer1::ITensor *ToTensor(LayerInput *input, const std::vector<size_t> &shape,
std::shared_ptr<TrtConverterContext> context) {
MS_EXCEPTION_IF_NULL(context);
if (input->IsTensor()) {
return input->tensor();
}
const nvinfer1::Dims &dim = TrtUtils::MsDimsToTrtDims(shape, false);
auto *const_layer = context->network()->addConstant(dim, *input->weight());
MS_EXCEPTION_IF_NULL(const_layer);
return const_layer->getOutput(0);
}
ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
@ -34,15 +47,7 @@ ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
auto *layer = context->network()->addShuffle(*inputs[0].tensor());
MS_EXCEPTION_IF_NULL(layer);
const auto &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
const auto &output_shape = AnfAlgo::GetOutputInferShape(node, 0);
if (input_shape[0] != output_shape[0]) {
MS_LOG(ERROR) << "Reshape does not support modify batch size. Input batch size: " << input_shape[0]
<< "Output batch size: " << output_shape[0];
return {false, {}};
}
const nvinfer1::Dims &dims = TrtUtils::MsDimsToTrtDims(output_shape, false);
layer->setReshapeDimensions(dims);
@ -62,7 +67,6 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
const std::vector<size_t> &x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
const std::vector<size_t> &y_shape = AnfAlgo::GetOutputInferShape(node, 0);
// Keep to output
auto Broadcast = [&context, &y_shape](nvinfer1::ITensor *tensor, const std::vector<size_t> &x_shape) {
if (x_shape.size() == y_shape.size()) {
return tensor;
@ -88,8 +92,8 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
return layer->getOutput(0);
};
auto *x1 = Broadcast(inputs[0].tensor(), x1_shape);
auto *x2 = Broadcast(inputs[1].tensor(), x2_shape);
auto *x1 = Broadcast(ToTensor(&inputs[0], x1_shape, context), x1_shape);
auto *x2 = Broadcast(ToTensor(&inputs[1], x2_shape, context), x2_shape);
auto *layer = context->network()->addElementWise(*x1, *x2, op_type);
MS_EXCEPTION_IF_NULL(layer);
@ -142,6 +146,80 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterCo
return {true, {LayerInput(layer->getOutput(0))}};
}
ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
nvinfer1::UnaryOperation op_type) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
return {false, {}};
}
auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
}
ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
nvinfer1::ReduceOperation op_type) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
return {false, {}};
}
// Calculate reduce axes bitmask
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
uint32_t reduce_axes = 0;
if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
for (size_t i = 0; i < axis.size(); i++) {
int offset = axis[i] >= 0 ? LongToInt(axis[i]) : LongToInt(axis[i] + input_shape.size());
reduce_axes |= 1UL << offset;
}
} else {
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
reduce_axes = 1UL << offset;
}
// Tensor-RT do not support reduce with no dimensions.
// Skip reduce operator if reduce_axes == 0
if (reduce_axes == 0) {
MS_LOG(WARNING) << "No dimension be be reduced. " << node->DebugString();
return {true, {LayerInput(inputs[0].tensor())}};
}
bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
// Tensor-RT do not support reduce all dimensions with keep_dims == false.
// Reduce with keep_dims = true, add apply reshape latter.
bool post_reshape = false;
if (keep_dims == false && (reduce_axes == (1UL << input_shape.size()) - 1)) {
keep_dims = true;
post_reshape = true;
}
nvinfer1::IReduceLayer *layer = context->network()->addReduce(*inputs[0].tensor(), op_type, reduce_axes, keep_dims);
MS_EXCEPTION_IF_NULL(layer);
if (post_reshape) {
nvinfer1::IShuffleLayer *reshape_layer = context->network()->addShuffle(*layer->getOutput(0));
MS_EXCEPTION_IF_NULL(reshape_layer);
nvinfer1::Dims dim;
dim.nbDims = 1;
dim.d[1] = 1;
reshape_layer->setReshapeDimensions(dim);
return {true, {LayerInput(reshape_layer->getOutput(0))}};
}
return {true, {LayerInput(layer->getOutput(0))}};
}
} // namespace
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
@ -195,6 +273,7 @@ MS_TRT_CONVERTER_FUNC_REG(Add) { return AddElementLayer(node, context, nvinfer1:
MS_TRT_CONVERTER_FUNC_REG(Sub) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kSUB); }
MS_TRT_CONVERTER_FUNC_REG(Mul) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPROD); }
MS_TRT_CONVERTER_FUNC_REG(Div) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
MS_TRT_CONVERTER_FUNC_REG(RealDiv) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
MS_TRT_CONVERTER_FUNC_REG(Pow) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPOW); }
MS_TRT_CONVERTER_FUNC_REG(Maximum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMAX); }
MS_TRT_CONVERTER_FUNC_REG(Minimum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMIN); }
@ -202,6 +281,33 @@ MS_TRT_CONVERTER_FUNC_REG(FloorDiv) {
return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kFLOOR_DIV);
}
// Unary operators
MS_TRT_CONVERTER_FUNC_REG(Exp) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kEXP); }
MS_TRT_CONVERTER_FUNC_REG(Log) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kLOG); }
MS_TRT_CONVERTER_FUNC_REG(Sqrt) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSQRT); }
MS_TRT_CONVERTER_FUNC_REG(Reciprocal) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kRECIP); }
MS_TRT_CONVERTER_FUNC_REG(Abs) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kABS); }
MS_TRT_CONVERTER_FUNC_REG(Neg) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kNEG); }
MS_TRT_CONVERTER_FUNC_REG(Sin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSIN); }
MS_TRT_CONVERTER_FUNC_REG(COS) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOS); }
MS_TRT_CONVERTER_FUNC_REG(Tan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kTAN); }
MS_TRT_CONVERTER_FUNC_REG(Sinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSINH); }
MS_TRT_CONVERTER_FUNC_REG(Cosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOSH); }
MS_TRT_CONVERTER_FUNC_REG(Asin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASIN); }
MS_TRT_CONVERTER_FUNC_REG(Acos) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOS); }
MS_TRT_CONVERTER_FUNC_REG(Atan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kATAN); }
MS_TRT_CONVERTER_FUNC_REG(Asinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASINH); }
MS_TRT_CONVERTER_FUNC_REG(Acosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOSH); }
MS_TRT_CONVERTER_FUNC_REG(Ceil) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCEIL); }
MS_TRT_CONVERTER_FUNC_REG(Floor) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kFLOOR); }
// Reduce operators
MS_TRT_CONVERTER_FUNC_REG(ReduceSum) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kSUM); }
MS_TRT_CONVERTER_FUNC_REG(ReduceMean) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kAVG); }
MS_TRT_CONVERTER_FUNC_REG(ReduceMax) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kMAX); }
MS_TRT_CONVERTER_FUNC_REG(ReduceMin) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kMIN); }
MS_TRT_CONVERTER_FUNC_REG(ReduceProd) { return addReduceLayer(node, context, nvinfer1::ReduceOperation::kPROD); }
// Pooling operators.
MS_TRT_CONVERTER_FUNC_REG(AvgPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kAVERAGE); }
MS_TRT_CONVERTER_FUNC_REG(MaxPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kMAX); }
@ -304,6 +410,45 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
}
}
MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 2) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
return {false, {}};
}
const auto &transpose_a = AnfAlgo::GetNodeAttr<bool>(node, "transpose_a");
const auto &transpose_b = AnfAlgo::GetNodeAttr<bool>(node, "transpose_b");
const auto &trt_transpose1 = transpose_a ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
const auto &trt_transpose2 = transpose_b ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
auto SwapLastDims = [](std::vector<size_t> shape, const bool &transpose) {
if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Operation not support: input rank should >= 2";
}
if (!transpose) {
return shape;
}
size_t tmp = shape[shape.size() - 2];
shape[shape.size() - 2] = shape[shape.size() - 1];
shape[shape.size() - 1] = tmp;
return shape;
};
nvinfer1::ITensor *tensor1 = ToTensor(&inputs[0], SwapLastDims(shape1, transpose_a), context);
nvinfer1::ITensor *tensor2 = ToTensor(&inputs[1], SwapLastDims(shape2, transpose_b), context);
auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
@ -321,7 +466,7 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
return {false, {}};
}
// Convert Weight to ITensor which
// Convert Weight to ITensor
nvinfer1::Dims unsqueeze_bias_dims;
unsqueeze_bias_dims.nbDims = x_shape.size();
std::fill(unsqueeze_bias_dims.d, unsqueeze_bias_dims.d + unsqueeze_bias_dims.nbDims, 1);
@ -335,10 +480,9 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
return {true, {LayerInput(layer->getOutput(0))}};
}
// NoOp
MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
@ -466,5 +610,207 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(Slice) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
return {false, {}};
}
const auto &begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "begin");
const auto &size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "size");
nvinfer1::Dims trt_start = TrtUtils::MsDimsToTrtDims(begin, false);
nvinfer1::Dims trt_size = TrtUtils::MsDimsToTrtDims(size, false);
nvinfer1::Dims trt_stride;
for (int32_t i = 0; i < trt_start.nbDims; i++) {
trt_stride.d[trt_stride.nbDims++] = 1;
}
auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(Transpose) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
return {false, {}};
}
const auto &perm = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
nvinfer1::Permutation trt_perm;
for (size_t i = 0; i < perm.size(); i++) {
trt_perm.order[i] = LongToInt(perm[i]);
}
auto *layer = context->network()->addShuffle(*inputs[0].tensor());
MS_EXCEPTION_IF_NULL(layer);
layer->setFirstTranspose(trt_perm);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(Softmax) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
return {false, {}};
}
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
uint32_t reduce_axes = 0;
if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
if (axis.size() != 1) {
MS_LOG(ERROR) << "Only one axis can be set. Axis size" << axis.size();
return {false, {}};
}
int offset = axis[0] >= 0 ? LongToInt(axis[0]) : LongToInt(axis[0] + input_shape.size());
reduce_axes = 1U << offset;
} else {
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
reduce_axes = 1UL << offset;
}
auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
MS_EXCEPTION_IF_NULL(layer);
layer->setAxes(reduce_axes);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
return {false, {}};
}
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
uint32_t reduce_axes = 1UL << offset;
auto *softmax_layer = context->network()->addSoftMax(*inputs[0].tensor());
MS_EXCEPTION_IF_NULL(softmax_layer);
softmax_layer->setAxes(reduce_axes);
auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
MS_EXCEPTION_IF_NULL(log_layer);
return {true, {LayerInput(log_layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(Gather) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 2) {
MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
return {false, {}};
}
const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
auto axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
axis = axis >= 0 ? axis : axis + input_shape.size();
nvinfer1::ITensor *input = ToTensor(&inputs[0], input_shape, context);
const std::vector<size_t> &indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
nvinfer1::ITensor *indices = ToTensor(&inputs[1], indices_shape, context);
auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(Cast) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
return {false, {}};
}
const TypeId &dst_type = AnfAlgo::GetOutputInferDataType(node, 0);
auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
auto *layer = context->network()->addIdentity(*inputs[0].tensor());
layer->setOutputType(0, trt_type);
return {true, {LayerInput(layer->getOutput(0))}};
}
MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret || inputs.size() != 3 || !inputs[0].IsTensor() || !inputs[1].IsWeight() || !inputs[2].IsWeight()) {
MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
return {false, {}};
}
// Calculate reduce axes
const std::vector<size_t> &input_shape = AnfAlgo::GetOutputInferShape(node, 0);
auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_norm_axis");
begin_norm_axis = begin_norm_axis >= 0 ? begin_norm_axis : begin_norm_axis + input_shape.size();
uint32_t reduce_axes = 0;
for (size_t i = LongToSize(begin_norm_axis); i < input_shape.size(); i++) {
reduce_axes |= 1UL << i;
}
// Reshape gamma and beta for broadcast
auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_params_axis");
begin_params_axis = begin_params_axis >= 0 ? begin_params_axis : begin_params_axis + input_shape.size();
std::vector<size_t> param_shape = input_shape;
for (size_t j = 0; j < LongToSize(begin_params_axis); j++) {
param_shape[j] = 1;
}
auto epsilon = AnfAlgo::GetNodeAttr<float>(node, "epsilon");
std::shared_ptr<tensor::Tensor> weight = context->CreateTempWeight(kNumberTypeFloat32, {1});
auto value = static_cast<float *>(weight->data_c());
value[0] = epsilon;
nvinfer1::Dims dim;
dim.nbDims = SizeToInt(input_shape.size());
for (size_t i = 0; i < input_shape.size(); i++) {
dim.d[i] = 1;
}
auto *epsilon_layer = context->network()->addConstant(dim, nvinfer1::Weights{nvinfer1::DataType::kFLOAT, value, 1});
MS_EXCEPTION_IF_NULL(epsilon_layer);
// y = (x - mean) / sqrt(var) * gamma + beta
auto *mean = context->network()->addReduce(*inputs[0].tensor(), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
MS_EXCEPTION_IF_NULL(mean);
auto *sub =
context->network()->addElementWise(*inputs[0].tensor(), *mean->getOutput(0), nvinfer1::ElementWiseOperation::kSUB);
MS_EXCEPTION_IF_NULL(sub);
auto *pow =
context->network()->addElementWise(*sub->getOutput(0), *sub->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
MS_EXCEPTION_IF_NULL(pow);
auto *var = context->network()->addReduce(*pow->getOutput(0), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
MS_EXCEPTION_IF_NULL(var);
auto *var_epsilon = context->network()->addElementWise(*var->getOutput(0), *epsilon_layer->getOutput(0),
nvinfer1::ElementWiseOperation::kSUM);
MS_EXCEPTION_IF_NULL(var_epsilon);
auto *std = context->network()->addUnary(*var_epsilon->getOutput(0), nvinfer1::UnaryOperation::kSQRT);
MS_EXCEPTION_IF_NULL(std);
auto *div =
context->network()->addElementWise(*sub->getOutput(0), *std->getOutput(0), nvinfer1::ElementWiseOperation::kDIV);
MS_EXCEPTION_IF_NULL(div);
auto *mul = context->network()->addElementWise(*div->getOutput(0), *ToTensor(&inputs[1], param_shape, context),
nvinfer1::ElementWiseOperation::kPROD);
MS_EXCEPTION_IF_NULL(mul);
auto *add = context->network()->addElementWise(*mul->getOutput(0), *ToTensor(&inputs[2], param_shape, context),
nvinfer1::ElementWiseOperation::kSUM);
MS_EXCEPTION_IF_NULL(add);
return {true, {LayerInput(add->getOutput(0))}};
}
} // namespace opt
} // namespace mindspore