forked from mindspore-Ecosystem/mindspore
md delete set_dataset_size interface
parent c45f79d36b
commit d0410d6191
@@ -3052,14 +3052,6 @@ class MindDataset(MappableDataset):
             self.dataset_size = num_rows
         return self.dataset_size

-    # manually set dataset_size as a tempoary solution.
-    def set_dataset_size(self, value):
-        logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
-        if value >= 0:
-            self.dataset_size = value
-        else:
-            raise ValueError('Set dataset_size with negative value {}'.format(value))
-
     def is_shuffled(self):
         if self.shuffle_option is None:
             return True
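For orientation: the method removed above (and from GeneratorDataset and TFRecordDataset in the next two hunks) was a stopgap that let callers overwrite the cached dataset_size instead of letting get_dataset_size compute it. A minimal, MindSpore-independent sketch of the pattern; the class name is invented for illustration:

class SizedDataset:
    """Toy stand-in for the dataset classes touched by this commit."""

    def __init__(self, rows):
        self._rows = list(rows)
        self.dataset_size = None  # lazily computed cache

    def get_dataset_size(self):
        # Compute once and cache -- the code path that survives this commit.
        if self.dataset_size is None:
            self.dataset_size = len(self._rows)
        return self.dataset_size

    def set_dataset_size(self, value):
        # The removed stopgap: a manual override of the cache, already flagged as deprecated.
        if value >= 0:
            self.dataset_size = value
        else:
            raise ValueError('Set dataset_size with negative value {}'.format(value))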
@@ -3503,13 +3495,6 @@ class GeneratorDataset(MappableDataset):
             self.dataset_size = num_rows
         return self.dataset_size

-    # manually set dataset_size as a temporary solution.
-    def set_dataset_size(self, value):
-        if value >= 0:
-            self.dataset_size = value
-        else:
-            raise ValueError('Set dataset_size with negative value {}'.format(value))
-
     def __deepcopy__(self, memodict):
         if id(self) in memodict:
             return memodict[id(self)]
@@ -3696,14 +3681,6 @@ class TFRecordDataset(SourceDataset):
             self.dataset_size = self.num_samples
         return self.dataset_size

-    # manually set dataset_size as a tempoary solution.
-    def set_dataset_size(self, value):
-        logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
-        if value >= 0:
-            self.dataset_size = value
-        else:
-            raise ValueError('Set dataset_size with negative value {}'.format(value))
-
     def is_shuffled(self):
         return self.shuffle_files

@@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
     dataset = TxtDataset(root, data_dir)
     sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
     de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
-    de_dataset.set_dataset_size(len(sampler))

     de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
                                 operations=transform_img)
@@ -156,7 +156,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_s
     dataset = TxtDataset(root, data_dir)
     sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
     de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
-    de_dataset.set_dataset_size(len(sampler))

     de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
     de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
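The two classification_dataset hunks above drop the same call: the value it passed, len(sampler), is already determined by the sampler handed to GeneratorDataset, so the manual override adds nothing once get_dataset_size can consult the source. A hypothetical sketch of that arithmetic; ToyDistributedSampler is invented for illustration and is not the project's DistributedSampler:

import math

class ToyDistributedSampler:
    """Illustrative only: splits indices round-robin across group_size ranks."""

    def __init__(self, dataset_len, rank, group_size):
        self.dataset_len = dataset_len
        self.rank = rank
        self.group_size = group_size

    def __iter__(self):
        # Each rank takes every group_size-th index, starting from its own rank.
        return iter(range(self.rank, self.dataset_len, self.group_size))

    def __len__(self):
        # The value the deleted set_dataset_size(len(sampler)) call passed in.
        return math.ceil(self.dataset_len / self.group_size)

# 10 samples over 4 ranks: rank 0 iterates indices 0, 4, 8 and reports length 3.
assert list(ToyDistributedSampler(10, 0, 4)) == [0, 4, 8]
assert len(ToyDistributedSampler(10, 0, 4)) == 3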
@@ -81,7 +81,6 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_

     dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
     ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
-    ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
     image_trans = [
         vc.Rescale(1.0 / 255.0, 0.0),
         vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
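The deleted line in the captcha loader computed the per-shard row count by hand with m.ceil(len(dataset) / num_shards); since num_shards and shard_id are already passed to GeneratorDataset, the same split is presumably what the sharded dataset derives on its own. The arithmetic in isolation (helper name is mine):

import math as m

def rows_per_shard(total_rows, num_shards):
    # Rows assigned to each shard, rounded up, as the removed call computed.
    return m.ceil(total_rows / num_shards)

assert rows_per_shard(10000, 8) == 1250
assert rows_per_shard(10001, 8) == 1251  # uneven totals round up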
@@ -173,7 +173,6 @@ def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
             yield train_eval_gen.__next__()

     ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
-    ds.set_dataset_size(numbers_of_batch)
     ds = ds.repeat(epochs)
     return ds

@@ -165,7 +165,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
             yield train_eval_gen.__next__()

     ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"])
-    ds.set_dataset_size(numbers_of_batch)
     ds = ds.repeat(epochs)
     return ds

@@ -161,7 +161,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):

     ds = de.GeneratorDataset(_iter_h5_data(),
                              ["ids", "weights", "labels"])
-    ds.set_dataset_size(numbers_of_batch)
     ds = ds.repeat(epochs)
     return ds

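All three _get_h5_dataset variants above lose the explicit numbers_of_batch size. One way a caller can keep a well-defined epoch length without the setter is to bound the generator itself; a hypothetical sketch, not code from this repository, assuming train_eval_gen yields (ids, weights, labels) batches:

def make_bounded_h5_iter(train_eval_gen, numbers_of_batch):
    """Return a generator function that stops after a known number of batches."""
    def _iter_h5_data():
        for _ in range(numbers_of_batch):
            # The epoch length is now implied by the loop bound
            # rather than set on the dataset afterwards.
            yield next(train_eval_gen)
    return _iter_h5_data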
@@ -23,8 +23,7 @@ from mindspore import log as logger
 from .config import bert_net_cfg


-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
     # apply repeat operations
     repeat_count = epoch_size
@@ -40,10 +39,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                              shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
     print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
     new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
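The second BERT hunk removes the data-sink resizing: when sinking was enabled it shrank the dataset to data_sink_steps * batch_size rows, then stretched the repeat count so the same total number of rows was still consumed. The arithmetic in isolation (helper name and numbers are illustrative):

def sink_adjusted_repeat(repeat_count, ori_dataset_size, data_sink_steps, batch_size):
    # Rows consumed per sink cycle, i.e. the new_size the removed block set.
    new_size = data_sink_steps * batch_size
    # Stretch the repeat count so repeat_count * ori_dataset_size rows are still seen.
    return int(repeat_count * ori_dataset_size // new_size)

# One epoch over 320000 rows, sinking 100 steps of batch 32 at a time -> repeat 100 times.
assert sink_adjusted_repeat(1, 320000, 100, 32) == 100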
@@ -94,11 +94,9 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shu
     """
     # create iter dataset
     dataset = HwVocRawDataset(data_url, usage=usage)
-    dataset_len = len(dataset)

     # wrapped with GeneratorDataset
     dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
-    dataset.set_dataset_size(dataset_len)
     dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

     channelswap_op = C.HWC2CHW()
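In the VOC hunk above both the cached dataset_len and the setter disappear. That only works if the wrapped source can report its own length; a hypothetical sketch of such a random-access source (the class is a stand-in, not the project's HwVocRawDataset):

class RandomAccessSource:
    """Indexable, sized source suitable for wrapping in a GeneratorDataset."""

    def __init__(self, images, labels):
        assert len(images) == len(labels)
        self._images = images
        self._labels = labels

    def __getitem__(self, index):
        return self._images[index], self._labels[index]

    def __len__(self):
        # With set_dataset_size gone, this is where the size information lives.
        return len(self._images)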
@@ -262,9 +262,6 @@ def test_concat_12():
     data1 = ds.GeneratorDataset(generator, ["col1"])
     data2 = ds.GeneratorDataset(generator_10, ["col1"])

-    data1.set_dataset_size(3)
-    data2.set_dataset_size(7)
-
     data3 = data1 + data2
     res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]

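Without the two setters, the sizes that test_concat_12 relied on (3 and 7 rows) have to come from the generators themselves. A hypothetical rewrite sketch; generator_3 and generator_7 are my names, the original generator/generator_10 definitions are not part of this diff, and the iterator call assumes the MindSpore API of this era:

import numpy as np
import mindspore.dataset as ds

def generator_3():
    # three rows: 0, 1, 2
    for i in range(3):
        yield (np.array([i]),)

def generator_7():
    # seven rows: 3 .. 9
    for i in range(3, 10):
        yield (np.array([i]),)

data1 = ds.GeneratorDataset(generator_3, ["col1"])
data2 = ds.GeneratorDataset(generator_7, ["col1"])
data3 = data1 + data2  # concat, as in the test
assert sum(1 for _ in data3.create_tuple_iterator()) == 10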
@@ -288,9 +285,6 @@ def test_concat_13():
     data1 = ds.GeneratorDataset(generator, ["col1"])
     data2 = ds.GeneratorDataset(generator_20, ["col1"])

-    data1.set_dataset_size(3)
-    data2.set_dataset_size(10)
-
     data1 = data1.batch(3)
     data2 = data2.batch(5)

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import numpy as np
-
 import mindspore.dataset as ds
 from mindspore import log as logger

@@ -161,18 +159,6 @@ def test_imagefolder():
     assert data.num_classes() == 4


-def test_generator():
-    def generator():
-        for i in range(64):
-            yield (np.array([i]),)
-
-    data1 = ds.GeneratorDataset(generator, ["data"])
-    data1.set_dataset_size(10)
-    assert data1.get_dataset_size() == 10
-    data1.output_shapes()
-    assert data1.get_dataset_size() == 10
-
-
 if __name__ == '__main__':
     # test_compare_v1_and_2()
     # test_imagefolder()
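The deleted test_generator illustrates why the setter had to go: the generator yields 64 rows, yet the test pinned the reported size to 10 and asserted that the override survived output_shapes(). A plain-Python restatement of that mismatch:

import numpy as np

def generator():
    for i in range(64):
        yield (np.array([i]),)

true_len = sum(1 for _ in generator())
assert true_len == 64          # what the source actually provides
forced_len = 10                # what set_dataset_size(10) claimed
assert forced_len != true_len  # the inconsistency the deprecated setter allowed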