forked from mindspore-Ecosystem/mindspore
Started vision decoupling
map validate params Changing bindings Change behavior of vision.cc Added signature and changed Cmake Compiling Need to fix compile Added compiling random transforms python changes for decoupling two failed test case remaining Compiling random choice passes all ut Changed assert review comments Added validate params add back return value Fix lint py lint fix pylint 2 Addressing comments
This commit is contained in:
parent
8938c2f5ee
commit
be46ccf721
|
@ -2,25 +2,27 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if (ENABLE_PYTHON)
|
||||
add_library(APItoPython OBJECT
|
||||
python/pybind_register.cc
|
||||
python/pybind_conversion.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
|
||||
python/bindings/dataset/engine/gnn/bindings.cc
|
||||
python/bindings/dataset/include/datasets_bindings.cc
|
||||
python/bindings/dataset/include/iterator_bindings.cc
|
||||
python/bindings/dataset/include/execute_binding.cc
|
||||
python/bindings/dataset/include/schema_bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
python/bindings/dataset/kernels/data/bindings.cc
|
||||
python/bindings/dataset/kernels/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/bindings.cc
|
||||
python/bindings/dataset/engine/gnn/bindings.cc
|
||||
python/bindings/dataset/kernels/data/bindings.cc
|
||||
python/bindings/dataset/kernels/image/bindings.cc
|
||||
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
|
||||
python/bindings/dataset/kernels/ir/bindings.cc
|
||||
python/bindings/dataset/kernels/ir/image/bindings.cc
|
||||
python/bindings/dataset/text/bindings.cc
|
||||
python/bindings/dataset/text/kernels/bindings.cc
|
||||
python/bindings/mindrecord/include/bindings.cc
|
||||
python/pybind_conversion.cc
|
||||
python/pybind_register.cc
|
||||
)
|
||||
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
endif ()
|
||||
|
|
|
@ -49,42 +49,5 @@ PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) {
|
|||
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
|
||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<ComposeOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(
|
||||
*m, "NoOp", "TensorOp that does nothing, for testing purposes only.")
|
||||
.def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
return std::make_shared<RandomChoiceOp>(t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
|
||||
.def(py::init([](double prob, const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||
if (prob < 0 || prob > 1) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
|
||||
}
|
||||
return std::make_shared<RandomApplyOp>(prob, t_ops);
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status PyListToTensorOperations(const py::list &py_ops, std::vector<std::shared_ptr<TensorOperation>> *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ops);
|
||||
for (auto op : py_ops) {
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
ops->emplace_back(std::make_shared<transforms::PreBuiltOperation>(op.cast<std::shared_ptr<TensorOp>>()));
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
ops->emplace_back(
|
||||
std::make_shared<transforms::PreBuiltOperation>(std::make_shared<PyFuncOp>(op.cast<py::function>())));
|
||||
} else if (py::isinstance<TensorOperation>(op)) {
|
||||
ops->emplace_back(op.cast<std::shared_ptr<TensorOperation>>());
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp, TensorOperation nor a pyfunc.");
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
|
||||
for (auto const &op : *ops) {
|
||||
RETURN_UNEXPECTED_IF_NULL(op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PYBIND_REGISTER(TensorOperation, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TensorOperation, std::shared_ptr<TensorOperation>>(*m, "TensorOperation");
|
||||
py::arg("TensorOperation");
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ComposeOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::ComposeOperation, TensorOperation, std::shared_ptr<transforms::ComposeOperation>>(
|
||||
*m, "ComposeOperation")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto compose = std::make_shared<transforms::ComposeOperation>(std::move(t_ops));
|
||||
THROW_IF_ERROR(compose->ValidateParams());
|
||||
return compose;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::RandomChoiceOperation, TensorOperation,
|
||||
std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation")
|
||||
.def(py::init([](const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto random_choice = std::make_shared<transforms::RandomChoiceOperation>(std::move(t_ops));
|
||||
THROW_IF_ERROR(random_choice->ValidateParams());
|
||||
return random_choice;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::RandomApplyOperation, TensorOperation,
|
||||
std::shared_ptr<transforms::RandomApplyOperation>>(*m, "RandomApplyOperation")
|
||||
.def(py::init([](double prob, const py::list &ops) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> t_ops;
|
||||
THROW_IF_ERROR(PyListToTensorOperations(ops, &t_ops));
|
||||
auto random_apply = std::make_shared<transforms::RandomApplyOperation>(std::move(t_ops), prob);
|
||||
THROW_IF_ERROR(random_apply->ValidateParams());
|
||||
return random_apply;
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/include/vision_lite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
|
||||
std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
|
||||
.def(py::init([](const py::list &py_policy) {
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
|
||||
for (auto &py_sub : py_policy) {
|
||||
cpp_policy.push_back({});
|
||||
for (auto handle : py_sub.cast<py::list>()) {
|
||||
py::tuple tp = handle.cast<py::tuple>();
|
||||
if (tp.is_none() || tp.size() != 2) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
|
||||
}
|
||||
std::shared_ptr<TensorOperation> t_op;
|
||||
if (py::isinstance<TensorOperation>(tp[0])) {
|
||||
t_op = (tp[0]).cast<std::shared_ptr<TensorOperation>>();
|
||||
} else if (py::isinstance<TensorOp>(tp[0])) {
|
||||
t_op = std::make_shared<transforms::PreBuiltOperation>((tp[0]).cast<std::shared_ptr<TensorOp>>());
|
||||
} else if (py::isinstance<py::function>(tp[0])) {
|
||||
t_op = std::make_shared<transforms::PreBuiltOperation>(
|
||||
std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
|
||||
}
|
||||
double prob = (tp[1]).cast<py::float_>();
|
||||
if (prob < 0 || prob > 1) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
|
||||
}
|
||||
cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
|
||||
}
|
||||
}
|
||||
return std::make_shared<vision::RandomSelectSubpolicyOperation>(cpp_policy);
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -92,13 +92,20 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
|
|||
std::shared_ptr<TensorOp> tensor_op;
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
|
||||
if (py::isinstance<TensorOperation>(op)) {
|
||||
vector.push_back(op.cast<std::shared_ptr<TensorOperation>>());
|
||||
} else {
|
||||
THROW_IF_ERROR([]() {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Error: tensor_op is not recognised (not TensorOp, TensorOperation and not pyfunc).");
|
||||
}());
|
||||
}
|
||||
}
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
|
@ -107,12 +114,15 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(py::list operat
|
|||
std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation) {
|
||||
std::shared_ptr<TensorOperation> op;
|
||||
std::shared_ptr<TensorOp> tensor_op;
|
||||
if (py::isinstance<TensorOp>(operation)) {
|
||||
if (py::isinstance<TensorOperation>(operation)) {
|
||||
op = operation.cast<std::shared_ptr<TensorOperation>>();
|
||||
} else if (py::isinstance<TensorOp>(operation)) {
|
||||
tensor_op = operation.cast<std::shared_ptr<TensorOp>>();
|
||||
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
|
||||
} else {
|
||||
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }());
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op or TensorOperation."); }());
|
||||
}
|
||||
op = std::make_shared<transforms::PreBuiltOperation>(tensor_op);
|
||||
return op;
|
||||
}
|
||||
|
||||
|
|
|
@ -88,7 +88,9 @@ Status MapNode::ValidateParams() {
|
|||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
for (const auto &op : operations_) {
|
||||
RETURN_IF_NOT_OK(op->ValidateParams());
|
||||
}
|
||||
if (!input_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
|
||||
}
|
||||
|
|
|
@ -381,5 +381,5 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|||
|
||||
def check_tensor_op(param, param_name):
|
||||
"""check whether param is a tensor op or a callable Python function"""
|
||||
if not isinstance(param, cde.TensorOp) and not callable(param):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))
|
||||
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
||||
|
|
|
@ -2320,10 +2320,16 @@ class MapDataset(Dataset):
|
|||
|
||||
def parse(self, children=None):
|
||||
column_order = replace_none(self.column_order, [])
|
||||
operations = []
|
||||
for op in self.operations:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
|
||||
cc = self.cache.cache_client if self.cache else None
|
||||
callbacks = [cb.create_runtime_obj() for cb in self.callbacks] if self.callbacks else []
|
||||
return cde.MapNode(children[0], self.operations, self.input_columns, self.output_columns, column_order, cc,
|
||||
return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, column_order, cc,
|
||||
callbacks).SetNumWorkers(self.num_parallel_workers)
|
||||
|
||||
def get_args(self):
|
||||
|
|
|
@ -327,7 +327,7 @@ class Unique(cde.UniqueOp):
|
|||
>>> # +---------+-----------------+---------+
|
||||
|
||||
"""
|
||||
class Compose(cde.ComposeOp):
|
||||
class Compose():
|
||||
"""
|
||||
Compose a list of transforms into a single transform.
|
||||
|
||||
|
@ -344,10 +344,18 @@ class Compose(cde.ComposeOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.ComposeOperation(operations)
|
||||
|
||||
class RandomApply(cde.RandomApplyOp):
|
||||
class RandomApply():
|
||||
"""
|
||||
Randomly perform a series of transforms with a given probability.
|
||||
|
||||
|
@ -365,10 +373,20 @@ class RandomApply(cde.RandomApplyOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms, prob=0.5):
|
||||
super().__init__(prob, transforms)
|
||||
self.transforms = transforms
|
||||
self.prob = prob
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.RandomApplyOperation(self.prob, operations)
|
||||
|
||||
|
||||
class RandomChoice(cde.RandomChoiceOp):
|
||||
class RandomChoice():
|
||||
"""
|
||||
Randomly selects one transform from a list of transforms to perform operation.
|
||||
|
||||
|
@ -385,4 +403,13 @@ class RandomChoice(cde.RandomChoiceOp):
|
|||
|
||||
@check_random_transform_ops
|
||||
def __init__(self, transforms):
|
||||
super().__init__(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def parse(self):
|
||||
operations = []
|
||||
for op in self.transforms:
|
||||
if op and getattr(op, 'parse', None):
|
||||
operations.append(op.parse())
|
||||
else:
|
||||
operations.append(op)
|
||||
return cde.RandomChoiceOperation(operations)
|
||||
|
|
|
@ -1322,14 +1322,14 @@ class UniformAugment(cde.UniformAugOp):
|
|||
super().__init__(transforms, num_ops)
|
||||
|
||||
|
||||
class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
|
||||
class RandomSelectSubpolicy():
|
||||
"""
|
||||
Choose a random sub-policy from a list to be applied on the input image. A sub-policy is a list of tuples
|
||||
(op, prob), where op is a TensorOp operation and prob is the probability that this op will be applied. Once
|
||||
a sub-policy is selected, each op within the subpolicy with be applied in sequence according to its probability.
|
||||
|
||||
Args:
|
||||
policy (list(list(tuple(TensorOp,float))): List of sub-policies to choose from.
|
||||
policy (list(list(tuple(TensorOp, float))): List of sub-policies to choose from.
|
||||
|
||||
Examples:
|
||||
>>> policy = [[(c_vision.RandomRotation((45, 45)), 0.5),
|
||||
|
@ -1343,7 +1343,22 @@ class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
|
|||
|
||||
@check_random_select_subpolicy_op
|
||||
def __init__(self, policy):
|
||||
super().__init__(policy)
|
||||
self.policy = policy
|
||||
|
||||
def parse(self):
|
||||
"""
|
||||
Return a C++ representation of the operator for execution
|
||||
"""
|
||||
policy = []
|
||||
for list_one in self.policy:
|
||||
policy_one = []
|
||||
for list_two in list_one:
|
||||
if hasattr(list_two[0], 'parse'):
|
||||
policy_one.append((list_two[0].parse(), list_two[1]))
|
||||
else:
|
||||
policy_one.append((list_two[0], list_two[1]))
|
||||
policy.append(policy_one)
|
||||
return cde.RandomSelectSubpolicyOperation(policy)
|
||||
|
||||
|
||||
class SoftDvppDecodeResizeJpeg(cde.SoftDvppDecodeResizeJpegOp):
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_compose():
|
|||
# Test exceptions.
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
|
||||
assert "op_list[0] is neither a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
|
||||
assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)
|
||||
|
||||
# Test empty op list
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_random_select_subpolicy():
|
|||
# test exceptions
|
||||
assert "policy can not be empty." in test_config([[1, 2, 3]], [])
|
||||
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
|
||||
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOp) nor a callable pyfunc" \
|
||||
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOperation) nor a callable pyfunc" \
|
||||
in test_config([[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
|
||||
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
|
||||
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])
|
||||
|
|
Loading…
Reference in New Issue