forked from mindspore-Ecosystem/mindspore
!10988 Tensor op decoupling stage 1
From: @ezphlow Reviewed-by: Signed-off-by:
This commit is contained in:
commit
7957a3b6f5
|
@ -2,25 +2,27 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if (ENABLE_PYTHON)
|
||||
add_library(APItoPython OBJECT
|
||||
python/pybind_register.cc
|
||||
python/pybind_conversion.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
|
||||
python/bindings/dataset/engine/gnn/bindings.cc
|
||||
python/bindings/dataset/include/datasets_bindings.cc
|
||||
python/bindings/dataset/include/iterator_bindings.cc
|
||||
python/bindings/dataset/include/execute_binding.cc
|
||||
python/bindings/dataset/include/schema_bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
python/bindings/dataset/kernels/data/bindings.cc
|
||||
python/bindings/dataset/kernels/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/bindings.cc
|
||||
python/bindings/dataset/engine/gnn/bindings.cc
|
||||
python/bindings/dataset/kernels/data/bindings.cc
|
||||
python/bindings/dataset/kernels/image/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
|
||||
python/bindings/dataset/kernels/ir/bindings.cc
|
||||
python/bindings/dataset/kernels/ir/image/bindings.cc
|
||||
python/bindings/dataset/text/bindings.cc
|
||||
python/bindings/dataset/text/kernels/bindings.cc
|
||||
python/bindings/mindrecord/include/bindings.cc
|
||||
python/pybind_conversion.cc
|
||||
python/pybind_register.cc
|
||||
)
|
||||
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
endif ()
|
||||
|
|
|
@ -49,42 +49,5 @@ PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) {
|
|||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<ComposeOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(
|
||||
*m, "NoOp", "TensorOp that does nothing, for testing purposes only.")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<RandomChoiceOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
|
||||
.def(py::init([](double prob, const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
if (prob < 0 || prob > 1) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
|
||||
}
|
||||
return std::make_shared<RandomApplyOp>(prob, t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status PyListToTensorOperations(const py::list &py_ops, std::vector<std::shared_ptr<TensorOperation>> *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ops);
|
||||
for (auto op : py_ops) {
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
ops->emplace_back(std::make_shared<transforms::PreBuiltOperation>(op.cast<std::shared_ptr<TensorOp>>()));
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
ops->emplace_back(
|
||||
std::make_shared<transforms::PreBuiltOperation>(std::make_shared<PyFuncOp>(op.cast<py::function>())));
|
||||
} else if (py::isinstance<TensorOperation>(op)) {
|
||||
ops->emplace_back(op.cast<std::shared_ptr<TensorOperation>>());
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp, TensorOperation nor a pyfunc.");
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
|
||||
for (auto const &op : *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PYBIND_REGISTER(TensorOperation, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TensorOperation, std::shared_ptr<TensorOperation>>(*m, "TensorOperation");
|
||||
py::arg("TensorOperation");
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ComposeOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::ComposeOperation, TensorOperation, std::shared_ptr<transforms::ComposeOperation>>(
|
||||
*m, "ComposeOperation")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto compose = std::make_shared<transforms::ComposeOperation>(std::move(t_ops));
|
||||
THROW_IF_ERROR(compose->ValidateParams());
|
||||
return compose;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::RandomChoiceOperation, TensorOperation,
|
||||
std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto random_choice = std::make_shared<transforms::RandomChoiceOperation>(std::move(t_ops));
|
||||
THROW_IF_ERROR(random_choice->ValidateParams());
|
||||
return random_choice;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::RandomApplyOperation, TensorOperation,
|
||||
std::shared_ptr<transforms::RandomApplyOperation>>(*m, "RandomApplyOperation")
|
||||
.def(py::init([](double prob, const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto random_apply = std::make_shared<transforms::RandomApplyOperation>(std::move(t_ops), prob);
|
||||
THROW_IF_ERROR(random_apply->ValidateParams());
|
||||
return random_apply;
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/include/vision_lite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
|
||||
std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
|
||||
.def(py::init([](const py::list &py_policy) {
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
|
||||
for (auto &py_sub : py_policy) {
|
||||
cpp_policy.push_back({});
|
||||
for (auto handle : py_sub.cast<py::list>()) {
|
||||
py::tuple tp = handle.cast<py::tuple>();
|
||||
if (tp.is_none() || tp.size() != 2) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
|
||||
}
|
||||
std::shared_ptr<TensorOperation> t_op;
|
||||
if (py::isinstance<TensorOperation>(tp[0])) {
|
||||
t_op = (tp[0]).cast<std::shared_ptr<TensorOperation>>();
|
||||
} else if (py::isinstance<TensorOp>(tp[0])) {
|
||||
t_op = std::make_shared<transforms::PreBuiltOperation>((tp[0]).cast<std::shared_ptr<TensorOp>>());
|
||||
} else if (py::isinstance<py::function>(tp[0])) {
|
||||
t_op = std::make_shared<transforms::PreBuiltOperation>(
|
||||
std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
|
||||
}
|
||||
double prob = (tp[1]).cast<py::float_>();
|
||||
if (prob < 0 || prob > 1) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
|
||||
}
|
||||
cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
|
||||
}
|
||||
}
|
||||
return std::make_shared<vision::RandomSelectSubpolicyOperation>(cpp_policy);
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -92,13 +92,20 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
|
|||
std::shared_ptr<TensorOp> tensor_op;
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
|
||||
if (py::isinstance<TensorOperation>(op)) {
|
||||
vector.push_back(op.cast<std::shared_ptr<TensorOperation>>());
|
||||
} else {
|
||||
THROW_IF_ERROR([]() {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Error: tensor_op is not recognised (not TensorOp, TensorOperation and not pyfunc).");
|
||||
}());
|
||||
}
|
||||
}
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
|
@ -107,12 +114,15 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
|
|||
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation) {
|
||||
std::shared_ptr<TensorOperation> op;
|
||||
std::shared_ptr<TensorOp> tensor_op;
|
||||
if (py::isinstance<TensorOp>(operation)) {
|
||||
if (py::isinstance<TensorOperation>(operation)) {
|
||||
op = operation.cast<std::shared_ptr<TensorOperation>>();
|
||||
} else if (py::isinstance<TensorOp>(operation)) {
|
||||
tensor_op = operation.cast<std::shared_ptr<TensorOp>>();
|
||||
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
|
||||
} else {
|
||||
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }());
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op or TensorOperation."); }());
|
||||
}
|
||||
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
|
||||
return op;
|
||||
}
|
||||
|
||||
|
|
|
@ -88,7 +88,9 @@ Status MapNode::ValidateParams() {
|
|||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
for (const auto &op : operations_) {
|
||||
RETURN_IF_NOT_OK(op->ValidateParams());
|
||||
}
|
||||
if (!input_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
|
||||
}
|
||||
|
|
|
@ -381,5 +381,5 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|||
|
||||
def check_tensor_op(param, param_name):
|
||||
"""check whether param is a tensor op or a callable Python function"""
|
||||
if not isinstance(param, cde.TensorOp) and not callable(param):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))
|
||||
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
||||
|
|
|
@ -2332,10 +2332,16 @@ class MapDataset(Dataset):
|
|||
|
||||
def parse(self, children=None):
|
||||
column_order = replace_none(self.column_order, [])
|
||||
operations = []
|
||||
for op in self.operations:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
|
||||
cc = self.cache.cache_client if self.cache else None
|
||||
callbacks = [cb.create_runtime_obj() for cb in self.callbacks] if self.callbacks else []
|
||||
return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns, column_order, cc,
|
||||
return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, column_order, cc,
|
||||
callbacks).SetNumWorkers(self.num_parallel_workers)
|
||||
|
||||
def get_args(self):
|
||||
|
|
|
@ -327,7 +327,7 @@ class Unique(cde.UniqueOp):
|
|||
>>> # +---------+-----------------+---------+
|
||||
|
||||
"""
|
||||
class Compose(cde.ComposeOp):
|
||||
class Compose():
|
||||
"""
|
||||
Compose a list of transforms into a single transform.
|
||||
|
||||
|
@ -344,10 +344,18 @@ class Compose(cde.ComposeOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.ComposeOperation(operations)
|
||||
|
||||
class RandomApply(cde.RandomApplyOp):
|
||||
class RandomApply():
|
||||
"""
|
||||
Randomly perform a series of transforms with a given probability.
|
||||
|
||||
|
@ -365,10 +373,20 @@ class RandomApply(cde.RandomApplyOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms, prob=0.5):
|
||||
super().__init__(prob, transforms)
|
||||
self.transforms = transforms
|
||||
self.prob = prob
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.RandomApplyOperation(self.prob, operations)
|
||||
|
||||
|
||||
class RandomChoice(cde.RandomChoiceOp):
|
||||
class RandomChoice():
|
||||
"""
|
||||
Randomly selects one transform from a list of transforms to perform operation.
|
||||
|
||||
|
@ -385,4 +403,13 @@ class RandomChoice(cde.RandomChoiceOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.RandomChoiceOperation(operations)
|
||||
|
|
|
@ -1325,14 +1325,14 @@ class UniformAugment(cde.UniformAugOp):
|
|||
super().__init__(transforms, num_ops)
|
||||
|
||||
|
||||
class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
|
||||
class RandomSelectSubpolicy():
|
||||
"""
|
||||
Choose a random sub-policy from a list to be applied on the input image. A sub-policy is a list of tuples
|
||||
(op, prob), where op is a TensorOp operation and prob is the probability that this op will be applied. Once
|
||||
a sub-policy is selected, each op within the subpolicy with be applied in sequence according to its probability.
|
||||
|
||||
Args:
|
||||
policy (list(list(tuple(TensorOp,float))): List of sub-policies to choose from.
|
||||
policy (list(list(tuple(TensorOp, float))): List of sub-policies to choose from.
|
||||
|
||||
Examples:
|
||||
>>> policy = [[(c_vision.RandomRotation((45, 45)), 0.5),
|
||||
|
@ -1346,7 +1346,22 @@ class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
|
|||
|
||||
@check_random_select_subpolicy_op
|
||||
def __init__(self, policy):
|
||||
super().__init__(policy)
|
||||
self.policy = policy
|
||||
|
||||
def parse(self):
|
||||
"""
|
||||
Return a C++ representation of the operator for execution
|
||||
"""
|
||||
policy = []
|
||||
for list_one in self.policy:
|
||||
policy_one = []
|
||||
for list_two in list_one:
|
||||
if hasattr(list_two[0], 'parse'):
|
||||
policy_one.append((list_two[0].parse(), list_two[1]))
|
||||
else:
|
||||
policy_one.append((list_two[0], list_two[1]))
|
||||
policy.append(policy_one)
|
||||
return cde.RandomSelectSubpolicyOperation(policy)
|
||||
|
||||
|
||||
class SoftDvppDecodeResizeJpeg(cde.SoftDvppDecodeResizeJpegOp):
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_compose():
|
|||
# Test exceptions.
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
|
||||
assert "op_list[0] is neither a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
|
||||
assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)
|
||||
|
||||
# Test empty op list
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_random_select_subpolicy():
|
|||
# test exceptions
|
||||
assert "policy can not be empty." in test_config([[1, 2, 3]], [])
|
||||
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
|
||||
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOp) nor a callable pyfunc" \
|
||||
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOperation) nor a callable pyfunc" \
|
||||
in test_config([[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
|
||||
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
|
||||
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])
|
||||
|
|
Loading…
Reference in New Issue