delete enable_fused_layernorm

yoonlee666 2020-09-03 11:40:38 +08:00
parent 4ec343961e
commit dfd85caa1b
10 changed files with 29 additions and 315 deletions

View File

@@ -161,7 +161,6 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as follows
├─dataset.py # data preprocessing
├─finetune_eval_config.py # parameter configuration for finetuning
├─finetune_eval_model.py # backbone code of network
├─fused_layer_norm.py # layer normalization optimized for Ascend
├─sample_process.py # sample processing
├─utils.py # util function
├─pretrain_eval.py # train and eval net

View File

@@ -25,7 +25,6 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
self.enable_fused_layernorm = enable_fused_layernorm
class EmbeddingLookup(nn.Cell):
@@ -245,19 +242,14 @@ class BertOutput(nn.Cell):
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
@@ -615,8 +607,7 @@ class BertSelfAttention(nn.Cell):
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -644,8 +635,7 @@ class BertSelfAttention(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
@@ -687,8 +677,7 @@ class BertEncoderCell(nn.Cell):
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
@@ -700,8 +689,7 @@ class BertEncoderCell(nn.Cell):
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
@@ -710,8 +698,7 @@ class BertEncoderCell(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
def construct(self, hidden_states, attention_mask):
# self-attention
@@ -758,8 +745,7 @@ class BertTransformer(nn.Cell):
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False,
enable_fused_layernorm=False):
return_all_encoders=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
@@ -776,8 +762,7 @@ class BertTransformer(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
@@ -904,8 +889,7 @@ class BertModel(nn.Cell):
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
return_all_encoders=True)
self.cast = P.Cast()
self.dtype = config.dtype
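For reference, the net effect of the BertOutput hunks above is a single normalization path: nn.LayerNorm is always constructed in the configured compute type. Below is a condensed, self-contained sketch of the resulting block, assembled from the fragments shown in this diff; the construct body follows the conventional dense → dropout → residual-add → LayerNorm order, which the diff does not show in full, so treat it as illustrative rather than a drop-in replacement.

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal


class BertOutputSketch(nn.Cell):
    """Projection + dropout + residual + LayerNorm, with no FusedLayerNorm branch."""
    def __init__(self, in_channels, out_channels, initializer_range=0.02,
                 dropout_prob=0.1, compute_type=mstype.float32):
        super(BertOutputSketch, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.add = P.TensorAdd()
        # After this commit there is exactly one normalization choice.
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(output, input_tensor)
        return self.layernorm(output)
```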

View File

@@ -1,122 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described by the following formula; it is applied across all channels
and pixels within a single sample.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_norm (bool): Whether to use batchnorm to process the input. Default: False.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape[1:]
>>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
"""Applies Layer Normalization over a mini-batch of inputs"""
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s
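The core trick in this deleted module (see get_shape_for_norm and construct above) is to fold every layer-norm group into the channel axis of an NCHW tensor of shape (1, -1, 1, prod(normalized_shape)), so that BatchNorm's per-channel statistics coincide with per-group layer-norm statistics. A minimal NumPy sketch of that equivalence, assuming gamma = 1, beta = 0, and the same epsilon on both sides:

```python
import numpy as np

def layer_norm(x, begin_norm_axis=-1, eps=1e-5):
    """Reference layer norm: normalize over the trailing axes, no affine part."""
    axes = tuple(range(begin_norm_axis % x.ndim, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm_via_batch_stats(x, begin_norm_axis=-1, eps=1e-5):
    """The FusedLayerNorm reshape trick: one normalization group per channel,
    so per-channel (batch-norm style) statistics equal layer-norm statistics."""
    norm_axis = begin_norm_axis % x.ndim
    norm_shape = x.shape[norm_axis:]
    reshaped = x.reshape(1, -1, 1, int(np.prod(norm_shape)))
    mean = reshaped.mean(axis=(0, 2, 3), keepdims=True)
    var = reshaped.var(axis=(0, 2, 3), keepdims=True)
    out = (reshaped - mean) / np.sqrt(var + eps)
    return out.reshape(x.shape)

x = np.random.randn(2, 8, 16).astype(np.float32)
assert np.allclose(layer_norm(x), layer_norm_via_batch_stats(x), atol=1e-5)
```

This batch-norm-backed path is what the removed enable_fused_layernorm flag switched on (described in the README parameter deleted below as using batchnorm instead of layernorm to improve performance).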

View File

@@ -113,7 +113,6 @@ For example, the dataset is cn-wiki-128, the schema file for general distill phase
├─__init__.py
├─assessment_method.py # assessment method for evaluation
├─dataset.py # data processing
├─fused_layer_norm.py # layer normalization optimized for Ascend
├─gd_config.py # parameter configuration for general distill phase
├─td_config.py # parameter configuration for task distill phase
├─tinybert_for_gd_td.py # backbone code of network
@@ -229,7 +228,6 @@ Parameters for bert network:
token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
enable_fused_layernorm use batchnorm instead of layernorm to improve performance, default is False
```
## [Training Process](#contents)
### Training
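After this commit the flag is gone from the parameter list above, and mixed precision is controlled through compute_type alone. A hedged usage sketch follows; the import path is a placeholder for wherever BertConfig is defined in this repository, and the values are illustrative:

```python
import mindspore.common.dtype as mstype
# Placeholder import path -- adjust to the actual module that defines BertConfig.
from src.tinybert_model import BertConfig

# Passing enable_fused_layernorm=... would now raise a TypeError; the remaining
# knobs are the ones documented in the parameter list above.
student_net_cfg = BertConfig(
    batch_size=32,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
```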

View File

@@ -1,122 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described by the following formula; it is applied across all channels
and pixels within a single sample.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_norm (bool): Whether to use batchnorm to process the input. Default: False.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape[1:]
>>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
"""fusedlayernorm"""
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s

View File

@@ -55,8 +55,7 @@ bert_teacher_net_cfg = BertConfig(
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False
compute_type=mstype.float16
)
bert_student_net_cfg = BertConfig(
batch_size=32,
@@ -76,6 +75,5 @@ bert_student_net_cfg = BertConfig(
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False
compute_type=mstype.float16
)

View File

@@ -74,8 +74,7 @@ td_teacher_net_cfg = BertConfig(
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False
compute_type=mstype.float16
)
td_student_net_cfg = BertConfig(
batch_size=32,
@@ -95,6 +94,5 @@ td_student_net_cfg = BertConfig(
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False
compute_type=mstype.float16
)

View File

@@ -25,7 +25,6 @@ from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore import context
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
self.enable_fused_layernorm = enable_fused_layernorm
class EmbeddingLookup(nn.Cell):
@@ -244,8 +241,7 @@ class BertOutput(nn.Cell):
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
@@ -256,11 +252,7 @@ class BertOutput(nn.Cell):
self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
self.compute_type = compute_type
else:
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
@@ -602,8 +594,7 @@ class BertSelfAttention(nn.Cell):
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -628,8 +619,7 @@ class BertSelfAttention(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
@@ -672,8 +662,7 @@ class BertEncoderCell(nn.Cell):
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
enable_fused_layernorm=False):
compute_type=mstype.float32):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
@@ -685,8 +674,7 @@ class BertEncoderCell(nn.Cell):
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
@@ -695,8 +683,7 @@ class BertEncoderCell(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
def construct(self, hidden_states, attention_mask):
"""bert encoder cell"""
# self-attention
@@ -743,8 +730,7 @@ class BertTransformer(nn.Cell):
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False,
enable_fused_layernorm=False):
return_all_encoders=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
layers = []
@@ -760,8 +746,7 @@ class BertTransformer(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
@@ -877,8 +862,7 @@ class BertModel(nn.Cell):
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
return_all_encoders=True)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
@@ -981,8 +965,7 @@ class TinyBertModel(nn.Cell):
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
return_all_encoders=True)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
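Note that, unlike the BERT version earlier in this commit, the TinyBERT BertOutput keeps a branch (the hunk at -256,11 above) that pins LayerNorm to float32 while the dense projection runs in the configured compute type. Below is a minimal sketch of that mixed-precision pattern; the keep_norm_fp32 flag stands in for the device check used in tinybert_model.py, is not part of the real signature, and the construct body is the conventional order rather than a verbatim copy:

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal


class MixedPrecisionOutputSketch(nn.Cell):
    """Dense in fp16, LayerNorm optionally kept in fp32 for numerical stability."""
    def __init__(self, in_channels, out_channels, dropout_prob=0.1,
                 compute_type=mstype.float16, keep_norm_fp32=True):
        super(MixedPrecisionOutputSketch, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(0.02)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.add = P.TensorAdd()
        if keep_norm_fp32:
            # Normalization statistics are accumulated in float32.
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()
        self.compute_type = compute_type

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(output, input_tensor)
        output = self.layernorm(output)
        # Hand a tensor of the expected dtype back to the fp16 layers downstream.
        return self.cast(output, self.compute_type)
```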

View File

@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False)
compute_type=mstype.float16)
else:
bert_config = BertConfig(batch_size=batch_size)
return bert_config

View File

@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=False)
compute_type=mstype.float16)
else:
bert_config = BertConfig(batch_size=batch_size)
return bert_config