forked from mindspore-Ecosystem/mindspore
!10975 Increase performance of fasttext
From: @zhouyaqiang0
Commit: 2f8010f805
@@ -71,21 +71,35 @@ class FastTextInferCell(nn.Cell):
         return predicted_idx


-def load_infer_dataset(batch_size, datafile):
+def load_infer_dataset(batch_size, datafile, bucket):
     """data loader for infer"""
-    data_set = ds.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
+    def batch_per_bucket(bucket_length, input_file):
+        input_file = input_file + '/test_dataset_bs_' + str(bucket_length) + '.mindrecord'
+        if not input_file:
+            raise FileNotFoundError("input file parameter must not be empty.")

-    type_cast_op = deC.TypeCast(mstype.int32)
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
-    data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")
-    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
+        data_set = ds.MindDataset(input_file,
+                                  columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
+        type_cast_op = deC.TypeCast(mstype.int32)
+        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
+        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
+        data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")
+
+        data_set = data_set.batch(batch_size, drop_remainder=False)
+        return data_set
+    for i, _ in enumerate(bucket):
+        bucket_len = bucket[i]
+        ds_per = batch_per_bucket(bucket_len, datafile)
+        if i == 0:
+            data_set = ds_per
+        else:
+            data_set = data_set + ds_per

     return data_set


 def run_fasttext_infer():
     """run infer with FastText"""
-    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
+    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path, bucket=config.test_buckets)
     fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)

     parameter_dict = load_checkpoint(args.model_ckpt)
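The reworked loader reads one MindRecord file per test bucket (test_dataset_bs_<length>.mindrecord) and chains the per-bucket datasets with the `+` operator, so every batch contains sequences padded to a single bucket length instead of the global maximum. Below is a minimal sketch of how the new signature would be consumed; the config import and data path are assumptions based on the rest of this diff, and the iteration loop is illustrative only:

    from src.config import config_db as config  # assumed: eval.py picks the config matching the dataset

    dataset = load_infer_dataset(batch_size=config.batch_size,
                                 datafile='/path/to/test_mindrecords',  # directory holding test_dataset_bs_*.mindrecord
                                 bucket=config.test_buckets)
    for batch in dataset.create_dict_iterator(output_numpy=True):
        # each batch comes from exactly one bucket, so src_tokens has one padded length per batch
        print(batch['src_tokens'].shape, batch['label_idx'].shape)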
@@ -107,7 +121,15 @@ def run_fasttext_infer():
     from sklearn.metrics import accuracy_score, classification_report
     target_sens = np.array(target_sens).flatten()
+    merge_target_sens = []
+    for target_sen in target_sens:
+        merge_target_sens.extend(target_sen)
+    target_sens = merge_target_sens
     predictions = np.array(predictions).flatten()
+    merge_predictions = []
+    for prediction in predictions:
+        merge_predictions.extend(prediction)
+    predictions = merge_predictions
     acc = accuracy_score(target_sens, predictions)

     result_report = classification_report(target_sens, predictions, target_names=target_label1)

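Because drop_remainder is now False and the buckets produce batches of unequal size, the collected per-batch results are ragged; flattening the outer list is no longer enough, so labels and predictions are merged into flat Python lists before sklearn sees them. A small self-contained illustration of that merge (the toy values below are invented for the example):

    import numpy as np
    from sklearn.metrics import accuracy_score

    # Ragged per-batch results, as bucketed evaluation with drop_remainder=False can produce.
    predictions = [np.array([1, 0, 1]), np.array([0, 1])]
    target_sens = [np.array([1, 0, 0]), np.array([0, 1])]

    merged_pred, merged_tgt = [], []
    for pred in predictions:
        merged_pred.extend(pred)
    for tgt in target_sens:
        merged_tgt.extend(tgt)

    print(accuracy_score(merged_tgt, merged_pred))  # 0.8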
@@ -60,7 +60,7 @@ then
     mkdir ./dbpedia
     cd ./dbpedia || exit
     echo "start data preprocess for device $DEVICE_ID"
-    python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120]
+    python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [64,128,256,512,3013] --test_bucket [64,128,256,512,1120]
     cd ..
 fi

@@ -74,7 +74,7 @@ then
     mkdir ./yelp_p
     cd ./yelp_p || exit
     echo "start data preprocess for device $DEVICE_ID"
-    python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955]
+    python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [64,128,256,512,2955]
     cd ..
 fi

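Both preprocessing calls now pass fine-grained test buckets instead of a single maximum-length bucket, so short test sentences are padded only to the nearest bucket boundary rather than to 1120 or 2955 tokens. A hedged sketch of the bucket-assignment idea (not the actual src/dataset.py code, just the padding rule the flags imply):

    # Illustrative only: pad each tokenized sentence to the smallest bucket that fits it.
    def assign_bucket(length, buckets):
        for bucket_len in sorted(buckets):
            if length <= bucket_len:
                return bucket_len
        return max(buckets)

    test_buckets = [64, 128, 256, 512, 1120]
    print(assign_bucket(37, test_buckets))   # 64   (previously padded to 1120 with --test_bucket [1120])
    print(assign_bucket(700, test_buckets))  # 1120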
@@ -32,7 +32,6 @@ DATASET=$(get_real_path $1)
 echo $DATASET
 DATANAME=$2
 MODEL_CKPT=$(get_real_path $3)
-ulimit -u unlimited
 export DEVICE_NUM=1
 export DEVICE_ID=5
 export RANK_ID=0
@@ -33,7 +33,6 @@ echo $DATASET
 DATANAME=$(basename $DATASET)
 echo $DATANAME

-ulimit -u unlimited
 export DEVICE_NUM=1
 export DEVICE_ID=0
 export RANK_ID=0
@@ -20,11 +20,12 @@ from easydict import EasyDict as ed
 config_yelpp = ed({
     'vocab_size': 6414979,
     'buckets': [64, 128, 256, 512, 2955],
-    'batch_size': 128,
+    'test_buckets': [64, 128, 256, 512, 2955],
+    'batch_size': 2048,
     'embedding_dims': 16,
     'num_class': 2,
     'epoch': 5,
-    'lr': 0.02,
+    'lr': 0.30,
     'min_lr': 1e-6,
     'decay_steps': 549,
     'warmup_steps': 400000,
@@ -37,12 +38,13 @@ config_yelpp = ed({

 config_db = ed({
     'vocab_size': 6596536,
-    'buckets': [128, 512, 3013],
-    'batch_size': 128,
+    'buckets': [64, 128, 256, 512, 3013],
+    'test_buckets': [64, 128, 256, 512, 1120],
+    'batch_size': 4096,
     'embedding_dims': 16,
     'num_class': 14,
     'epoch': 5,
-    'lr': 0.05,
+    'lr': 0.8,
     'min_lr': 1e-6,
     'decay_steps': 549,
     'warmup_steps': 400000,
@@ -56,15 +58,16 @@ config_db = ed({
 config_ag = ed({
     'vocab_size': 1383812,
     'buckets': [64, 128, 467],
-    'batch_size': 128,
+    'test_buckets': [467],
+    'batch_size': 512,
     'embedding_dims': 16,
     'num_class': 4,
     'epoch': 5,
-    'lr': 0.05,
+    'lr': 0.2,
     'min_lr': 1e-6,
     'decay_steps': 115,
     'warmup_steps': 400000,
-    'poly_lr_scheduler_power': 0.5,
+    'poly_lr_scheduler_power': 0.001,
     'epoch_count': 1,
     'pretrain_ckpt_dir': None,
     'save_ckpt_steps': 116,
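Each dataset config now carries a test_buckets list next to the training buckets, and batch_size and lr are raised roughly in proportion (much larger batches with a correspondingly larger peak learning rate). A short sketch of consuming these configs; the name-to-config selector below is a hypothetical helper added for illustration, only the module path and field names come from this diff:

    from src.config import config_ag, config_db, config_yelpp

    # Hypothetical selector, not part of this commit.
    def select_config(data_name):
        return {'ag': config_ag, 'dbpedia': config_db, 'yelp_p': config_yelpp}[data_name]

    config = select_config('dbpedia')
    print(config.batch_size, config.lr, config.test_buckets)  # 4096 0.8 [64, 128, 256, 512, 1120]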
@@ -59,9 +59,7 @@ class FastText(nn.Cell):
         src_tokens = self.embeding_func(src_tokens)
         embeding = self.reducesum(src_tokens, 1)

-        length_tiled = self.tile(src_token_length, (1, self.embeding_dims))
-
-        embeding = self.realdiv(embeding, length_tiled)
+        embeding = self.realdiv(embeding, src_token_length)

         embeding = self.cast(embeding, mstype.float16)
         classifer = self.fc(embeding)
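The forward pass mean-pools the bag of token embeddings: sum over the sequence axis, then divide by the sentence length. Tiling src_token_length across embedding_dims before the divide was redundant, since an elementwise divide broadcasts a (batch, 1) length against the (batch, embedding_dims) sum anyway, so dropping the Tile removes one op and its memory traffic from every step. A small numpy sketch of the equivalence, with shapes assumed from the diff:

    import numpy as np

    batch, embedding_dims = 4, 16
    embeding_sum = np.random.rand(batch, embedding_dims).astype(np.float32)  # ReduceSum over tokens
    src_token_length = np.array([[3], [7], [2], [5]], dtype=np.float32)      # (batch, 1)

    old = embeding_sum / np.tile(src_token_length, (1, embedding_dims))  # explicit tile
    new = embeding_sum / src_token_length                                # broadcasting, same result

    print(np.allclose(old, new))  # True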
@@ -13,9 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """FastText data loader"""
-import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.c_transforms as deC


 def load_dataset(dataset_path,
@@ -37,14 +35,10 @@ def load_dataset(dataset_path,
                               shuffle=shuffle,
                               num_shards=rank_size,
                               shard_id=rank_id,
-                              num_parallel_workers=8)
+                              num_parallel_workers=4)
     ori_dataset_size = data_set.get_dataset_size()
     print(f"Dataset size: {ori_dataset_size}")
     repeat_count = epoch_count
-    type_cast_op = deC.TypeCast(mstype.int32)
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
-    data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")

     data_set = data_set.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
                                output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
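Dropping the three TypeCast maps removes a per-column cast from every sample, which in turn lets the input pipeline get by with four parallel workers instead of eight; this only holds if the regenerated MindRecord files already store the columns as int32. A hedged sketch of the writer side that would make the cast unnecessary (the schema and file name below are assumptions for illustration, not copied from src/dataset.py):

    import numpy as np
    from mindspore.mindrecord import FileWriter

    # Assumed schema: columns written as int32 up front, so the loader can skip TypeCast.
    schema = {"src_tokens": {"type": "int32", "shape": [-1]},
              "src_tokens_length": {"type": "int32", "shape": [-1]},
              "label_idx": {"type": "int32", "shape": [-1]}}

    writer = FileWriter(file_name="train_dataset_bs_64.mindrecord", shard_num=1)
    writer.add_schema(schema, "fasttext bucket 64")
    writer.write_raw_data([{"src_tokens": np.zeros(64, dtype=np.int32),
                            "src_tokens_length": np.array([1], dtype=np.int32),
                            "label_idx": np.array([0], dtype=np.int32)}])
    writer.commit()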
@@ -136,7 +136,7 @@ def _build_training_pipeline(pre_dataset):
     loss_monitor = LossCallBack(rank_ids=rank_id)
     dataset_size = pre_dataset.get_dataset_size()
     time_monitor = TimeMonitor(data_size=dataset_size)
-    ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps,
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * config.epoch,
                                    keep_checkpoint_max=config.keep_ckpt_max)
     callbacks = [time_monitor, loss_monitor]
     if rank_size is None or int(rank_size) == 1:
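The checkpoint interval is stretched from decay_steps to decay_steps * config.epoch, so checkpoints are written config.epoch times less often and checkpoint I/O during training drops accordingly. A minimal sketch of the callback wiring, with the prefix, directory, and keep_checkpoint_max chosen for illustration rather than taken from train.py:

    from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

    decay_steps = 549   # value from config_db / config_yelpp in this commit
    epoch = 5           # config.epoch

    ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * epoch,  # save far less often
                                   keep_checkpoint_max=1)
    ckpt_cb = ModelCheckpoint(prefix='fasttext', directory='./ckpt', config=ckpt_config)
    # model.train(epoch, dataset, callbacks=[time_monitor, loss_monitor, ckpt_cb])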