Started vision decoupling

map validate params

Changing bindings

Change behavior of vision.cc

Added signature and changed Cmake

Compiling

Need to fix compile

Added compiling random transforms

python changes for decoupling

two failed test case remaining

Compiling random choice

passes all ut

Changed assert

review comments

Added validate params

add back return value

Fix lint

py lint fix

pylint 2

Addressing comments
This commit is contained in:
Eric 2020-12-03 22:55:03 -05:00
parent 8938c2f5ee
commit be46ccf721
12 changed files with 249 additions and 68 deletions

View File

@ -2,25 +2,27 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(APItoPython OBJECT
python/pybind_register.cc
python/pybind_conversion.cc
python/bindings/dataset/callback/bindings.cc
python/bindings/dataset/core/bindings.cc
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
python/bindings/dataset/engine/datasetops/source/bindings.cc
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
python/bindings/dataset/engine/gnn/bindings.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/execute_binding.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/core/bindings.cc
python/bindings/dataset/callback/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
python/bindings/dataset/engine/datasetops/source/bindings.cc
python/bindings/dataset/engine/gnn/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc
python/bindings/dataset/kernels/image/bindings.cc
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
python/bindings/dataset/kernels/ir/bindings.cc
python/bindings/dataset/kernels/ir/image/bindings.cc
python/bindings/dataset/text/bindings.cc
python/bindings/dataset/text/kernels/bindings.cc
python/bindings/mindrecord/include/bindings.cc
python/pybind_conversion.cc
python/pybind_register.cc
)
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
endif ()

View File

@ -49,42 +49,5 @@ PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) {
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
}));
PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) {
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<ComposeOp>(t_ops);
}));
}));
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(
*m, "NoOp", "TensorOp that does nothing, for testing purposes only.")
.def(py::init<>());
}));
PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) {
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<RandomChoiceOp>(t_ops);
}));
}));
PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) {
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
.def(py::init([](double prob, const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
}
return std::make_shared<RandomApplyOp>(prob, t_ops);
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/kernels/py_func_op.h"
namespace mindspore {
namespace dataset {
// Converts a Python list of transforms into TensorOperation handles appended to *ops.
//
// Each element is classified in this order:
//   1. a TensorOp instance        -> wrapped in a transforms::PreBuiltOperation
//   2. a Python function          -> wrapped in a PyFuncOp, then in a transforms::PreBuiltOperation
//   3. a TensorOperation instance -> appended unchanged
// Any other element produces an error Status. After conversion, an empty *ops or a null entry
// in *ops is also reported as an error.
//
// @param py_ops Python list holding the transforms to convert.
// @param ops    Output vector; must not be null.
// @return Status::OK() when every element was converted and *ops is non-empty.
Status PyListToTensorOperations(const py::list &py_ops, std::vector<std::shared_ptr<TensorOperation>> *ops) {
  RETURN_UNEXPECTED_IF_NULL(ops);
  for (auto item : py_ops) {
    if (py::isinstance<TensorOp>(item)) {
      auto built_op = item.cast<std::shared_ptr<TensorOp>>();
      ops->emplace_back(std::make_shared<transforms::PreBuiltOperation>(std::move(built_op)));
      continue;
    }
    if (py::isinstance<py::function>(item)) {
      auto py_func_op = std::make_shared<PyFuncOp>(item.cast<py::function>());
      ops->emplace_back(std::make_shared<transforms::PreBuiltOperation>(std::move(py_func_op)));
      continue;
    }
    if (py::isinstance<TensorOperation>(item)) {
      ops->emplace_back(item.cast<std::shared_ptr<TensorOperation>>());
      continue;
    }
    RETURN_STATUS_UNEXPECTED("element is neither a TensorOp, TensorOperation nor a pyfunc.");
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
  // Variable names passed to RETURN_UNEXPECTED_IF_NULL are kept as-is in case the macro reports them.
  for (auto const &op : *ops) {
    RETURN_UNEXPECTED_IF_NULL(op);
  }
  return Status::OK();
}
// Exposes TensorOperation to Python as "TensorOperation" with no constructor or methods bound.
// The operation bindings in this file name it as their base class in py::class_.
// The former `py::arg("TensorOperation");` statement only built and discarded a temporary,
// so it has been removed.
PYBIND_REGISTER(TensorOperation, 0, ([](const py::module *m) {
                  (void)py::class_<TensorOperation, std::shared_ptr<TensorOperation>>(*m, "TensorOperation");
                }));
// Python binding for transforms::ComposeOperation, exposed as "ComposeOperation".
// The bound constructor takes a Python list of transforms, converts it with
// PyListToTensorOperations, builds the operation and runs ValidateParams() on it.
// Both resulting Status values are handed to THROW_IF_ERROR.
PYBIND_REGISTER(ComposeOperation, 1, ([](const py::module *m) {
                  using OperationT = transforms::ComposeOperation;
                  (void)py::class_<OperationT, TensorOperation, std::shared_ptr<OperationT>>(*m, "ComposeOperation")
                    .def(py::init([](const py::list &ops) {
                      std::vector<std::shared_ptr<TensorOperation>> converted;
                      THROW_IF_ERROR(PyListToTensorOperations(ops, &converted));
                      auto operation = std::make_shared<OperationT>(std::move(converted));
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
// Python binding for transforms::RandomChoiceOperation, exposed as "RandomChoiceOperation".
// The bound constructor takes a Python list of transforms, converts it with
// PyListToTensorOperations, builds the operation and runs ValidateParams() on it.
// Both resulting Status values are handed to THROW_IF_ERROR.
PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) {
                  using OperationT = transforms::RandomChoiceOperation;
                  (void)py::class_<OperationT, TensorOperation, std::shared_ptr<OperationT>>(*m,
                                                                                             "RandomChoiceOperation")
                    .def(py::init([](const py::list &ops) {
                      std::vector<std::shared_ptr<TensorOperation>> converted;
                      THROW_IF_ERROR(PyListToTensorOperations(ops, &converted));
                      auto operation = std::make_shared<OperationT>(std::move(converted));
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
// Python binding for transforms::RandomApplyOperation, exposed as "RandomApplyOperation".
// The bound constructor takes (prob, list of transforms). The list is converted with
// PyListToTensorOperations, the operation is built from (transforms, prob), and
// ValidateParams() is run on it. Both resulting Status values are handed to THROW_IF_ERROR.
PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) {
                  using OperationT = transforms::RandomApplyOperation;
                  (void)py::class_<OperationT, TensorOperation, std::shared_ptr<OperationT>>(*m,
                                                                                             "RandomApplyOperation")
                    .def(py::init([](double prob, const py::list &ops) {
                      std::vector<std::shared_ptr<TensorOperation>> converted;
                      THROW_IF_ERROR(PyListToTensorOperations(ops, &converted));
                      auto operation = std::make_shared<OperationT>(std::move(converted), prob);
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/include/vision_lite.h"
namespace mindspore {
namespace dataset {
// Python binding for vision::RandomSelectSubpolicyOperation, exposed as
// "RandomSelectSubpolicyOperation".
//
// The bound constructor takes a Python list of sub-policies. Each sub-policy is a list of
// (op, prob) tuples, where op is one of:
//   - a TensorOperation   -> used unchanged
//   - a TensorOp          -> wrapped in a transforms::PreBuiltOperation
//   - a Python function   -> wrapped in a PyFuncOp, then in a transforms::PreBuiltOperation
// and prob must lie in [0, 1]. Malformed tuples, unsupported op types and out-of-range
// probabilities are reported through THROW_IF_ERROR. The finished operation is passed through
// ValidateParams() before it is returned, matching the other operation bindings.
PYBIND_REGISTER(
  RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
    (void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
                     std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
      .def(py::init([](const py::list &py_policy) {
        std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
        for (auto &py_sub : py_policy) {
          // Start a new, empty sub-policy; the inner loop fills it.
          cpp_policy.push_back({});
          for (auto handle : py_sub.cast<py::list>()) {
            py::tuple tp = handle.cast<py::tuple>();
            if (tp.is_none() || tp.size() != 2) {
              THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
            }
            std::shared_ptr<TensorOperation> t_op;
            if (py::isinstance<TensorOperation>(tp[0])) {
              t_op = (tp[0]).cast<std::shared_ptr<TensorOperation>>();
            } else if (py::isinstance<TensorOp>(tp[0])) {
              t_op = std::make_shared<transforms::PreBuiltOperation>((tp[0]).cast<std::shared_ptr<TensorOp>>());
            } else if (py::isinstance<py::function>(tp[0])) {
              t_op = std::make_shared<transforms::PreBuiltOperation>(
                std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
            } else {
              THROW_IF_ERROR(
                Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
            }
            double prob = (tp[1]).cast<py::float_>();
            if (prob < 0 || prob > 1) {
              THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
            }
            // Construct the pair in place instead of building a temporary with std::make_pair.
            cpp_policy.back().emplace_back(t_op, prob);
          }
        }
        auto random_select_subpolicy = std::make_shared<vision::RandomSelectSubpolicyOperation>(cpp_policy);
        THROW_IF_ERROR(random_select_subpolicy->ValidateParams());
        return random_select_subpolicy;
      }));
  }));
} // namespace dataset
} // namespace mindspore

View File

@ -92,13 +92,20 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(op)) {
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
} else if (py::isinstance<py::function>(op)) {
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
if (py::isinstance<TensorOperation>(op)) {
vector.push_back(op.cast<std::shared_ptr<TensorOperation>>());
} else {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED(
"Error: tensor_op is not recognised (not TensorOp, TensorOperation and not pyfunc).");
}());
}
}
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
}
}
return vector;
@ -107,12 +114,15 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation) {
std::shared_ptr<TensorOperation> op;
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(operation)) {
if (py::isinstance<TensorOperation>(operation)) {
op = operation.cast<std::shared_ptr<TensorOperation>>();
} else if (py::isinstance<TensorOp>(operation)) {
tensor_op = operation.cast<std::shared_ptr<TensorOp>>();
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }());
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op or TensorOperation."); }());
}
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
return op;
}

View File

@ -88,7 +88,9 @@ Status MapNode::ValidateParams() {
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &op : operations_) {
RETURN_IF_NOT_OK(op->ValidateParams());
}
if (!input_columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
}

View File

@ -381,5 +381,5 @@ def check_gnn_list_or_ndarray(param, param_name):
def check_tensor_op(param, param_name):
"""check whether param is a tensor op or a callable Python function"""
if not isinstance(param, cde.TensorOp) and not callable(param):
raise TypeError("{0} is neither a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))

View File

@ -2320,10 +2320,16 @@ class MapDataset(Dataset):
def parse(self, children=None):
column_order = replace_none(self.column_order, [])
operations = []
for op in self.operations:
if op and getattr(op, 'parse', None):
operations.append(op.parse())
else:
operations.append(op)
cc = self.cache.cache_client if self.cache else None
callbacks = [cb.create_runtime_obj() for cb in self.callbacks] if self.callbacks else []
return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns, column_order, cc,
return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, column_order, cc,
callbacks).SetNumWorkers(self.num_parallel_workers)
def get_args(self):

View File

@ -327,7 +327,7 @@ class Unique(cde.UniqueOp):
>>> # +---------+-----------------+---------+
"""
class Compose(cde.ComposeOp):
class Compose():
"""
Compose a list of transforms into a single transform.
@ -344,10 +344,18 @@ class Compose(cde.ComposeOp):
@check_random_transform_ops
def __init__(self, transforms):
super().__init__(transforms)
self.transforms = transforms
def parse(self):
    """
    Build the cde.ComposeOperation for the stored transforms.

    A transform that is truthy and has a truthy ``parse`` attribute is replaced by the
    result of calling ``parse()``; every other entry is passed through unchanged.
    """
    parsed_ops = [op.parse() if op and getattr(op, 'parse', None) else op for op in self.transforms]
    return cde.ComposeOperation(parsed_ops)
class RandomApply(cde.RandomApplyOp):
class RandomApply():
"""
Randomly perform a series of transforms with a given probability.
@ -365,10 +373,20 @@ class RandomApply(cde.RandomApplyOp):
@check_random_transform_ops
def __init__(self, transforms, prob=0.5):
super().__init__(prob, transforms)
self.transforms = transforms
self.prob = prob
def parse(self):
    """
    Build the cde.RandomApplyOperation for the stored probability and transforms.

    A transform that is truthy and has a truthy ``parse`` attribute is replaced by the
    result of calling ``parse()``; every other entry is passed through unchanged.
    """
    parsed_ops = [op.parse() if op and getattr(op, 'parse', None) else op for op in self.transforms]
    return cde.RandomApplyOperation(self.prob, parsed_ops)
class RandomChoice(cde.RandomChoiceOp):
class RandomChoice():
"""
Randomly selects one transform from a list of transforms to perform operation.
@ -385,4 +403,13 @@ class RandomChoice(cde.RandomChoiceOp):
@check_random_transform_ops
def __init__(self, transforms):
super().__init__(transforms)
self.transforms = transforms
def parse(self):
    """
    Build the cde.RandomChoiceOperation for the stored transforms.

    A transform that is truthy and has a truthy ``parse`` attribute is replaced by the
    result of calling ``parse()``; every other entry is passed through unchanged.
    """
    parsed_ops = [op.parse() if op and getattr(op, 'parse', None) else op for op in self.transforms]
    return cde.RandomChoiceOperation(parsed_ops)

View File

@ -1322,14 +1322,14 @@ class UniformAugment(cde.UniformAugOp):
super().__init__(transforms, num_ops)
class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
class RandomSelectSubpolicy():
"""
Choose a random sub-policy from a list to be applied on the input image. A sub-policy is a list of tuples
(op, prob), where op is a TensorOp operation and prob is the probability that this op will be applied. Once
a sub-policy is selected, each op within the subpolicy will be applied in sequence according to its probability.
Args:
policy (list(list(tuple(TensorOp,float))): List of sub-policies to choose from.
policy (list(list(tuple(TensorOp, float))): List of sub-policies to choose from.
Examples:
>>> policy = [[(c_vision.RandomRotation((45, 45)), 0.5),
@ -1343,7 +1343,22 @@ class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
@check_random_select_subpolicy_op
def __init__(self, policy):
super().__init__(policy)
self.policy = policy
def parse(self):
    """
    Return a C++ representation of the operator for execution.

    Each (op, prob) entry keeps its probability. The op is replaced by ``op.parse()`` when it
    has a ``parse`` attribute and is passed through unchanged otherwise.
    """
    def _to_cpp(op):
        return op.parse() if hasattr(op, 'parse') else op

    parsed_policy = [[(_to_cpp(entry[0]), entry[1]) for entry in sub_policy] for sub_policy in self.policy]
    return cde.RandomSelectSubpolicyOperation(parsed_policy)
class SoftDvppDecodeResizeJpeg(cde.SoftDvppDecodeResizeJpegOp):

View File

@ -63,7 +63,7 @@ def test_compose():
# Test exceptions.
with pytest.raises(TypeError) as error_info:
c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
assert "op_list[0] is neither a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)
# Test empty op list
with pytest.raises(ValueError) as error_info:

View File

@ -41,7 +41,7 @@ def test_random_select_subpolicy():
# test exceptions
assert "policy can not be empty." in test_config([[1, 2, 3]], [])
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOp) nor a callable pyfunc" \
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOperation) nor a callable pyfunc" \
in test_config([[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])