# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
|
|
Data operations, will be used in train.py and eval.py
|
|
"""
|
|
import os

import numpy as np

from imdb import ImdbParser
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter


def create_dataset(data_home, batch_size, repeat_num=1, training=True):
    """Create the train or test dataset from the converted MindRecord files."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")

    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)

    # shuffle the whole dataset, then batch and repeat it
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
    data_set = data_set.repeat(count=repeat_num)

    return data_set

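# Usage sketch (illustrative, not part of the original module; the path and
# batch size below are assumptions): load the training split after the
# MindRecord files have been generated by convert_to_mindrecord below.
#
#     train_dataset = create_dataset("./preprocess", batch_size=64)
#     for item in train_dataset.create_dict_iterator():
#         feature, label = item["feature"], item["label"]
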
def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """
    Convert the parsed IMDB data to MindRecord files.
    """
    # persist the embedding weight table produced by the parser
    if weight_np is not None:
        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)

    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord")

    def get_imdb_data(features, labels):
        data_list = []
        for i, (label, feature) in enumerate(zip(labels, features)):
            data_json = {"id": i,
                         "label": int(label),
                         "feature": feature.reshape(-1)}
            data_list.append(data_json)
        return data_list

    writer = FileWriter(data_dir, shard_num=4)
    data = get_imdb_data(features, labels)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

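# Note: FileWriter(..., shard_num=4) splits each split into four shard files,
# aclImdb_train.mindrecord0 ... aclImdb_train.mindrecord3 (likewise for test),
# which is why create_dataset above opens the "*.mindrecord0" file; MindDataset
# locates the sibling shards from that first file's metadata.
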
def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
    """
    Convert the IMDB dataset to MindRecord files for training and testing.
    """
    parser = ImdbParser(aclimdb_path, glove_path, embed_size)
    parser.parse()

    if not os.path.exists(preprocess_path):
        print(f"preprocess path {preprocess_path} does not exist, creating it")
        os.makedirs(preprocess_path)

    train_features, train_labels, train_weight_np = parser.get_datas('train')
    _convert_to_mindrecord(preprocess_path, train_features, train_labels, train_weight_np)

    test_features, test_labels, _ = parser.get_datas('test')
    _convert_to_mindrecord(preprocess_path, test_features, test_labels, training=False)

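# End-to-end usage sketch (the paths and embed_size below are illustrative
# assumptions, not fixed by this module): parse and convert once, then build
# datasets from the converted files.
#
#     convert_to_mindrecord(embed_size=300,
#                           aclimdb_path="./aclImdb",
#                           preprocess_path="./preprocess",
#                           glove_path="./glove")
#     dataset = create_dataset("./preprocess", batch_size=64)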