!5648 GetColumnNames for Python
Merge pull request !5648 from MahdiRahmaniHanzaki/get-col-name
This commit is contained in:
commit
d76ac7c6e8
|
@ -44,7 +44,14 @@ PYBIND_REGISTER(
|
|||
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
|
||||
.def("SetBatchParameters",
|
||||
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
|
||||
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
|
||||
.def("GetColumnNames",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetColumnNames(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsMap",
|
||||
[](DEPipeline &de) {
|
||||
py::dict out;
|
||||
|
|
|
@ -172,9 +172,11 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
|
|||
// Function to assign the node as root.
|
||||
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
|
||||
|
||||
// Function to prepare the tree
|
||||
Status DEPipeline::PrepareTree(const int32_t num_epochs) { return tree_->Prepare(num_epochs); }
|
||||
|
||||
// Function to launch the tree execution.
|
||||
Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
|
||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
|
||||
Status DEPipeline::LaunchTreeExec() {
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
iterator_ = std::make_unique<DatasetIterator>(tree_);
|
||||
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
|
||||
|
@ -189,6 +191,25 @@ void DEPipeline::PrintTree() {
|
|||
}
|
||||
}
|
||||
|
||||
// Function to get the column names of the root node of the tree, ordered by column id.
// @param output - Python list to which the column names are appended
// @return Status - OK on success, an unexpected-error Status if the tree is not prepared or
//     the root node reports no columns
Status DEPipeline::GetColumnNames(py::list *output) {
  if (!tree_->isPrepared()) {
    RETURN_STATUS_UNEXPECTED("GetColumnNames: Make sure to call prepare before calling GetColumnNames.");
  }
  // Bind by const reference: no extra map copy if the accessor returns a reference, and the
  // temporary is lifetime-extended if it returns by value.
  const std::unordered_map<std::string, int32_t> &column_name_id_map = tree_->root()->column_name_id_map();
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("GetColumnNames: Column names was empty. Make sure Prepare is called.");
  }
  // The map is unordered, so sort the (name, id) pairs by id to recover the column order.
  std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
                                                                     column_name_id_map.end());
  std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
            [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
              return a.second < b.second;
            });
  // Iterate by const reference so each (string, id) pair is not copied per element.
  for (const auto &item : column_name_id_vector) {
    output->append(item.first);
  }
  return Status::OK();
}
|
||||
|
||||
Status DEPipeline::GetNextAsMap(py::dict *output) {
|
||||
TensorMap row;
|
||||
Status s;
|
||||
|
|
|
@ -92,8 +92,14 @@ class DEPipeline {
|
|||
// Function to assign the node as root.
|
||||
Status AssignRootNode(const DsOpPtr &dataset_op);
|
||||
|
||||
// Function to get the column names of the root node of the tree, ordered by column id
|
||||
Status GetColumnNames(py::list *output);
|
||||
|
||||
// Function to prepare the tree for execution
|
||||
Status PrepareTree(const int32_t num_epochs);
|
||||
|
||||
// Function to launch the tree execution.
|
||||
Status LaunchTreeExec(int32_t num_epochs);
|
||||
Status LaunchTreeExec();
|
||||
|
||||
// Get a row of data as dictionary of column name to the value.
|
||||
Status GetNextAsMap(py::dict *output);
|
||||
|
|
|
@ -83,7 +83,8 @@ void GeneratorOp::Dealloc() noexcept {
|
|||
PyGILState_STATE gstate;
|
||||
gstate = PyGILState_Ensure();
|
||||
// GC the generator object within GIL
|
||||
(void)generator_.dec_ref();
|
||||
if (generator_function_.ref_count() == 1) generator_function_.dec_ref();
|
||||
if (generator_.ref_count() == 1) (void)generator_.dec_ref();
|
||||
// Release GIL
|
||||
PyGILState_Release(gstate);
|
||||
}
|
||||
|
|
|
@ -211,6 +211,13 @@ class ExecutionTree {
|
|||
// @return Bool - true is ExecutionTree is finished
|
||||
bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; }
|
||||
|
||||
// Return if the ExecutionTree is ready.
|
||||
// @return Bool - true is ExecutionTree is ready
|
||||
bool isPrepared() const {
|
||||
return tree_state_ == TreeState::kDeTStateReady || tree_state_ == kDeTStateExecuting ||
|
||||
tree_state_ == kDeTStateFinished;
|
||||
}
|
||||
|
||||
// Set the ExecutionTree to Finished state.
|
||||
void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; }
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ from mindspore._c_expression import typing
|
|||
|
||||
from mindspore import log as logger
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
|
||||
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
check_rename, check_numpyslicesdataset, check_device_send, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
|
@ -1203,6 +1203,12 @@ class Dataset:
|
|||
self._repeat_count = device_iter.get_repeat_count()
|
||||
device_iter.stop()
|
||||
|
||||
def get_col_names(self):
    """
    Get names of the columns in the dataset.

    Returns:
        The column names reported by a temporary Iterator built over this dataset.
    """
    column_probe = Iterator(self)
    return column_probe.get_col_names()
|
||||
|
||||
def output_shapes(self):
|
||||
"""
|
||||
Get the shapes of output data.
|
||||
|
|
|
@ -93,7 +93,7 @@ class Iterator:
|
|||
|
||||
root = self.__convert_node_postorder(self.dataset)
|
||||
self.depipeline.AssignRootNode(root)
|
||||
self.depipeline.LaunchTreeExec(self.num_epochs)
|
||||
self.depipeline.PrepareTree(self.num_epochs)
|
||||
self._index = 0
|
||||
|
||||
def stop(self):
|
||||
|
@ -276,6 +276,9 @@ class Iterator:
|
|||
def num_classes(self):
    """Return whatever the underlying pipeline reports from GetNumClasses()."""
    class_count = self.depipeline.GetNumClasses()
    return class_count
|
||||
|
||||
def get_col_names(self):
    """Return the column names reported by the underlying pipeline's GetColumnNames()."""
    column_names = self.depipeline.GetColumnNames()
    return column_names
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
|
@ -283,6 +286,10 @@ class SaveOp(Iterator):
|
|||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1):
    # Run the base initializer first, then launch tree execution explicitly.
    super().__init__(dataset, num_epochs)
    pipeline = self.depipeline
    pipeline.LaunchTreeExec()
|
||||
|
||||
def get_next(self):
|
||||
pass
|
||||
|
||||
|
@ -298,6 +305,10 @@ class DictIterator(Iterator):
|
|||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1):
    # Run the base initializer first, then launch tree execution explicitly.
    super().__init__(dataset, num_epochs)
    pipeline = self.depipeline
    pipeline.LaunchTreeExec()
|
||||
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
|
@ -328,6 +339,7 @@ class TupleIterator(Iterator):
|
|||
columns = [columns]
|
||||
dataset = dataset.project(columns)
|
||||
super().__init__(dataset, num_epochs)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
||||
CELEBA_DIR = "../data/dataset/testCelebAData"
|
||||
CIFAR10_DIR = "../data/dataset/testCifar10Data"
|
||||
CIFAR100_DIR = "../data/dataset/testCifar100Data"
|
||||
CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json"
|
||||
COCO_DIR = "../data/dataset/testCOCO/train"
|
||||
COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json"
|
||||
CSV_DIR = "../data/dataset/testCSV/1.csv"
|
||||
IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/"
|
||||
MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest"
|
||||
MNIST_DIR = "../data/dataset/testMnistData"
|
||||
TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
|
||||
TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
|
||||
VOC_DIR = "../data/dataset/testVOC2012"
|
||||
|
||||
|
||||
def test_get_column_name_celeba():
    """CelebADataset should report the columns image and attr, in that order."""
    expected_columns = ["image", "attr"]
    celeba = ds.CelebADataset(CELEBA_DIR)
    assert celeba.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_cifar10():
    """Cifar10Dataset should report the columns image and label, in that order."""
    expected_columns = ["image", "label"]
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_cifar100():
    """Cifar100Dataset should report image, coarse_label and fine_label, in that order."""
    expected_columns = ["image", "coarse_label", "fine_label"]
    cifar100 = ds.Cifar100Dataset(CIFAR100_DIR)
    assert cifar100.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_clue():
    """CLUEDataset (AFQMC, train) should report label, sentence1 and sentence2, in that order."""
    expected_columns = ["label", "sentence1", "sentence2"]
    clue = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
    assert clue.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_coco():
    """CocoDataset in Detection mode should report image, bbox, category_id and iscrowd."""
    expected_columns = ["image", "bbox", "category_id", "iscrowd"]
    coco = ds.CocoDataset(COCO_DIR, annotation_file=COCO_ANNOTATION, task="Detection",
                          decode=True, shuffle=False)
    assert coco.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_csv():
    """CSVDataset column names: those reported without column_names, then user-supplied ones."""
    # No column_names given.
    csv_data = ds.CSVDataset(CSV_DIR)
    assert csv_data.get_col_names() == ["1", "2", "3", "4"]

    # Explicit column_names given.
    csv_data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"])
    assert csv_data.get_col_names() == ["col1", "col2", "col3", "col4"]
|
||||
|
||||
|
||||
def test_get_column_name_generator():
    """A single-column GeneratorDataset should report exactly the column name it was given."""
    def one_column_rows():
        for value in range(64):
            yield (np.array([value]),)

    generated = ds.GeneratorDataset(one_column_rows, ["data"])
    assert generated.get_col_names() == ["data"]
|
||||
|
||||
|
||||
def test_get_column_name_imagefolder():
    """ImageFolderDatasetV2 should report the columns image and label, in that order."""
    expected_columns = ["image", "label"]
    image_folder = ds.ImageFolderDatasetV2(IMAGE_FOLDER_DIR)
    assert image_folder.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_iterator():
    """Tuple and dict iterators over Cifar10 should both report image and label."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    # Tuple iterator first, then dict iterator; the same name is rebound each time.
    for make_iterator in (cifar10.create_tuple_iterator, cifar10.create_dict_iterator):
        itr = make_iterator(num_epochs=1)
        assert itr.get_col_names() == ["image", "label"]
|
||||
|
||||
|
||||
def test_get_column_name_manifest():
    """ManifestDataset should report the columns image and label, in that order."""
    expected_columns = ["image", "label"]
    manifest = ds.ManifestDataset(MANIFEST_DIR)
    assert manifest.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_map():
    """Column names after map(): default, with output_columns, and with columns_order."""
    crop = vision.CenterCrop(10)
    # Each case: extra keyword arguments for map() and the column names expected afterwards.
    cases = [
        ({}, ["image", "label"]),
        ({"output_columns": ["image"]}, ["image", "label"]),
        ({"output_columns": ["col1"]}, ["col1", "label"]),
        ({"output_columns": ["col1", "col2"], "columns_order": ["col2", "col1"]}, ["col2", "col1"]),
    ]
    for extra_kwargs, expected_columns in cases:
        mapped = ds.Cifar10Dataset(CIFAR10_DIR)
        mapped = mapped.map(input_columns=["image"], operations=crop, **extra_kwargs)
        assert mapped.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_mnist():
    """MnistDataset should report the columns image and label, in that order."""
    expected_columns = ["image", "label"]
    mnist = ds.MnistDataset(MNIST_DIR)
    assert mnist.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_numpy_slices():
    """NumpySlicesDataset column names for dict input and for plain list input."""
    # Dict input: the test expects the dict keys as column names.
    keyed_input = {"a": [1, 2], "b": [3, 4]}
    sliced = ds.NumpySlicesDataset(keyed_input, shuffle=False)
    assert sliced.get_col_names() == ["a", "b"]

    # Plain list input: the test expects the single column to be named column_0.
    sliced = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
    assert sliced.get_col_names() == ["column_0"]
|
||||
|
||||
|
||||
def test_get_column_name_tfrecord():
    """TFRecordDataset column names with a schema file, a columns_list, no schema, and a Schema object."""
    # Schema file only.
    schema_columns = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
                      "col_sint64"]
    tfrecord = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
    assert tfrecord.get_col_names() == schema_columns

    # Schema file plus an explicit columns_list.
    subset = ["col_sint16", "col_sint64", "col_2d", "col_binary"]
    tfrecord = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA, columns_list=list(subset))
    assert tfrecord.get_col_names() == subset

    # No schema at all.
    schemaless_columns = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
                          "col_sint64", "col_sint8"]
    tfrecord = ds.TFRecordDataset(TFRECORD_DIR)
    assert tfrecord.get_col_names() == schemaless_columns

    # Schema object built in code.
    s = ds.Schema()
    for column_name, column_shape in (("line", []), ("words", [-1]), ("chinese", [])):
        s.add_column(column_name, "string", column_shape)

    tfrecord = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
    assert tfrecord.get_col_names() == ["line", "words", "chinese"]
|
||||
|
||||
|
||||
def test_get_column_name_to_device():
    """A Cifar10 dataset passed through to_device() should still report image and label."""
    expected_columns = ["image", "label"]
    transferred = ds.Cifar10Dataset(CIFAR10_DIR)
    transferred = transferred.to_device()
    assert transferred.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_voc():
    """VOCDataset in Segmentation mode should report the columns image and target."""
    expected_columns = ["image", "target"]
    voc = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
    assert voc.get_col_names() == expected_columns
|
||||
|
||||
|
||||
def test_get_column_name_project():
    """project() should narrow the reported columns to the projected ones."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == ["image", "label"]

    # Keep only the image column.
    cifar10 = cifar10.project(columns=["image"])
    assert cifar10.get_col_names() == ["image"]
|
||||
|
||||
|
||||
def test_get_column_name_rename():
    """rename() should replace the reported column names with the new ones."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == ["image", "label"]

    # image -> test1, label -> test2
    cifar10 = cifar10.rename(["image", "label"], ["test1", "test2"])
    assert cifar10.get_col_names() == ["test1", "test2"]
|
||||
|
||||
|
||||
def test_get_column_name_zip():
    """zip() should report the first dataset's columns followed by the second's."""
    image_columns = ["image", "label"]
    csv_columns = ["1", "2", "3", "4"]

    data1 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert data1.get_col_names() == image_columns

    data2 = ds.CSVDataset(CSV_DIR)
    assert data2.get_col_names() == csv_columns

    zipped = ds.zip((data1, data2))
    assert zipped.get_col_names() == ["image", "label", "1", "2", "3", "4"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every test in this file, in the same order as they were listed originally.
    for run_case in (
            test_get_column_name_celeba,
            test_get_column_name_cifar10,
            test_get_column_name_cifar100,
            test_get_column_name_clue,
            test_get_column_name_coco,
            test_get_column_name_csv,
            test_get_column_name_generator,
            test_get_column_name_imagefolder,
            test_get_column_name_iterator,
            test_get_column_name_manifest,
            test_get_column_name_map,
            test_get_column_name_mnist,
            test_get_column_name_numpy_slices,
            test_get_column_name_tfrecord,
            test_get_column_name_to_device,
            test_get_column_name_voc,
            test_get_column_name_project,
            test_get_column_name_rename,
            test_get_column_name_zip,
    ):
        run_case()
|
Loading…
Reference in New Issue