!12521 [WIP] MindData C++ Plugin infrastructure with GDAL decode
From: @ziruiwu Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5b7ed56d57
|
@ -75,6 +75,7 @@ add_subdirectory(engine)
|
|||
add_subdirectory(api)
|
||||
add_subdirectory(text)
|
||||
add_subdirectory(callback)
|
||||
add_subdirectory(plugin)
|
||||
######################################################################
|
||||
add_dependencies(utils core)
|
||||
add_dependencies(kernels-image core)
|
||||
|
@ -103,6 +104,7 @@ add_dependencies(kernels-ir core)
|
|||
add_dependencies(kernels-ir-data core)
|
||||
add_dependencies(kernels-ir-vision core)
|
||||
|
||||
|
||||
if(ENABLE_ACL)
|
||||
add_dependencies(kernels-dvpp-image core dvpp-utils)
|
||||
endif()
|
||||
|
@ -156,6 +158,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:kernels-ir>
|
||||
$<TARGET_OBJECTS:kernels-ir-data>
|
||||
$<TARGET_OBJECTS:kernels-ir-vision>
|
||||
$<TARGET_OBJECTS:md_plugin>
|
||||
)
|
||||
|
||||
if(ENABLE_ACL)
|
||||
|
|
|
@ -19,9 +19,10 @@
|
|||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
#include "minddata/dataset/kernels/data/no_op.h"
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/vision_ir.h"
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -52,6 +53,17 @@ PYBIND_REGISTER(TensorOperation, 0, ([](const py::module *m) {
|
|||
py::arg("TensorOperation");
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
PluginOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::PluginOperation, TensorOperation, std::shared_ptr<transforms::PluginOperation>>(
|
||||
*m, "PluginOperation")
|
||||
.def(py::init<std::string, std::string, std::string>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp").def(py::init<>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ComposeOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<transforms::ComposeOperation, TensorOperation, std::shared_ptr<transforms::ComposeOperation>>(
|
||||
|
|
|
@ -154,7 +154,7 @@ Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type,
|
|||
}
|
||||
|
||||
RETURN_IF_NOT_OK((*out)->AllocateBuffer(length));
|
||||
int ret_code = memcpy_s((*out)->data_, length, src, length);
|
||||
int ret_code = memcpy_ss((*out)->data_, length, src, length);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -91,7 +91,7 @@ inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; }
|
|||
|
||||
inline void BitClear(uint32_t *bits, uint32_t bitMask) { *bits &= (~bitMask); }
|
||||
|
||||
constexpr int32_t kDeMaxDim = std::numeric_limits<int32_t>::max(); // 2147483647 or 2^32 -1
|
||||
constexpr int64_t kDeMaxDim = std::numeric_limits<int64_t>::max();
|
||||
constexpr int32_t kDeMaxRank = std::numeric_limits<int32_t>::max();
|
||||
constexpr int64_t kDeMaxFreq = std::numeric_limits<int64_t>::max(); // 9223372036854775807 or 2^(64-1)
|
||||
constexpr int64_t kDeMaxTopk = std::numeric_limits<int64_t>::max();
|
||||
|
|
|
@ -3,19 +3,24 @@ add_subdirectory(data)
|
|||
add_subdirectory(ir)
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
|
||||
set(COMMON_TENSOR_OPS
|
||||
data/compose_op.cc
|
||||
data/random_apply_op.cc
|
||||
data/random_choice_op.cc
|
||||
tensor_op.cc
|
||||
plugin_op.cc
|
||||
)
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
add_library(kernels OBJECT
|
||||
data/compose_op.cc
|
||||
data/random_apply_op.cc
|
||||
data/random_choice_op.cc
|
||||
${COMMON_TENSOR_OPS}
|
||||
c_func_op.cc
|
||||
py_func_op.cc
|
||||
tensor_op.cc)
|
||||
)
|
||||
target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
else()
|
||||
add_library(kernels OBJECT
|
||||
data/compose_op.cc
|
||||
data/random_apply_op.cc
|
||||
data/random_choice_op.cc
|
||||
tensor_op.cc)
|
||||
${COMMON_TENSOR_OPS})
|
||||
endif()
|
||||
|
|
|
@ -39,8 +39,10 @@
|
|||
#include "minddata/dataset/kernels/data/slice_op.h"
|
||||
#endif
|
||||
#include "minddata/dataset/kernels/data/type_cast_op.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/data/unique_op.h"
|
||||
#include "minddata/dataset/kernels/plugin_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
@ -271,6 +273,18 @@ Status TypeCastOperation::to_json(nlohmann::json *out_json) {
|
|||
Status UniqueOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> UniqueOperation::Build() { return std::make_shared<UniqueOp>(); }
|
||||
Status PluginOperation::ValidateParams() {
|
||||
std::string err_msg;
|
||||
err_msg += lib_path_.empty() ? "lib_path is empty, please specify a path to .so file. " : "";
|
||||
err_msg += func_name_.empty() ? "func_name_ is empty, please specify function name to load." : "";
|
||||
if (!err_msg.empty()) {
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<TensorOp> PluginOperation::Build() {
|
||||
return std::make_shared<PluginOp>(lib_path_, func_name_, user_args_);
|
||||
}
|
||||
#endif
|
||||
} // namespace transforms
|
||||
} // namespace dataset
|
||||
|
|
|
@ -41,6 +41,7 @@ constexpr char kRandomApplyOperation[] = "RandomApply";
|
|||
constexpr char kRandomChoiceOperation[] = "RandomChoice";
|
||||
constexpr char kTypeCastOperation[] = "TypeCast";
|
||||
constexpr char kUniqueOperation[] = "Unique";
|
||||
constexpr char kPluginOperation[] = "Plugin";
|
||||
|
||||
// Transform operations for performing data transformation.
|
||||
namespace transforms {
|
||||
|
@ -264,7 +265,28 @@ class UniqueOperation : public TensorOperation {
|
|||
|
||||
std::string Name() const override { return kUniqueOperation; }
|
||||
};
|
||||
|
||||
class PluginOperation : public TensorOperation {
|
||||
public:
|
||||
explicit PluginOperation(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
|
||||
: lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {}
|
||||
|
||||
~PluginOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kPluginOperation; }
|
||||
|
||||
private:
|
||||
std::string lib_path_;
|
||||
std::string func_name_;
|
||||
std::string user_args_;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace transforms
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/plugin_op.h"
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/plugin/plugin_loader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
errno_t memcpy_s_loop(uchar *dest, size_t destMax, const uchar *src, size_t count) {
|
||||
int64_t step = 0;
|
||||
while (count >= SECUREC_MEM_MAX_LEN) {
|
||||
int ret_code = memcpy_s(dest + step * SECUREC_MEM_MAX_LEN, SECUREC_MEM_MAX_LEN, src + step * SECUREC_MEM_MAX_LEN,
|
||||
SECUREC_MEM_MAX_LEN);
|
||||
if (ret_code != 0) return ret_code;
|
||||
count -= SECUREC_MEM_MAX_LEN;
|
||||
step++;
|
||||
}
|
||||
return memcpy_s(dest + step * SECUREC_MEM_MAX_LEN, count, src + step * SECUREC_MEM_MAX_LEN, count);
|
||||
}
|
||||
|
||||
Status PluginOp::PluginToTensorRow(const std::vector<plugin::Tensor> &in_row, TensorRow *out_row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
|
||||
out_row->reserve(in_row.size());
|
||||
for (const auto &tensor : in_row) {
|
||||
std::shared_ptr<Tensor> output;
|
||||
DataType tp = DataType(tensor.type_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.IsNumeric() && tp != DataType::DE_UNKNOWN, "Unsupported type: " + tensor.type_);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(tensor.shape_), tp, tensor.buffer_.data(), &output));
|
||||
out_row->emplace_back(output);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PluginOp::TensorRowToPlugin(const TensorRow &in_row, std::vector<plugin::Tensor> *out_row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
|
||||
out_row->resize(in_row.size());
|
||||
for (size_t ind = 0; ind < in_row.size(); ind++) {
|
||||
plugin::Tensor &tensor = (*out_row)[ind];
|
||||
if (in_row[ind]->type().IsNumeric()) {
|
||||
dsize_t buffer_size = in_row[ind]->SizeInBytes();
|
||||
tensor.buffer_.resize(buffer_size);
|
||||
int ret_code = memcpy_s_loop(tensor.buffer_.data(), buffer_size, in_row[ind]->GetBuffer(), buffer_size);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
|
||||
|
||||
} else { // string tensor, for now, only tensor with 1 string is supported!
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in_row[ind]->shape().NumOfElements() == 1,
|
||||
"String tensor with more than 1 element is not yet supported.");
|
||||
// get the first and only string in this tensor
|
||||
std::string str1(*(in_row[ind]->begin<std::string_view>()));
|
||||
tensor.buffer_.resize(str1.size());
|
||||
std::memcpy(tensor.buffer_.data(), str1.data(), str1.size());
|
||||
}
|
||||
tensor.shape_ = in_row[ind]->shape().AsVector();
|
||||
tensor.type_ = in_row[ind]->type().ToString();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PluginOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
// Compute should quit if init fails. Error code has already been logged, no need to repeat
|
||||
RETURN_IF_NOT_OK(init_code_);
|
||||
std::vector<plugin::Tensor> in_row, out_row;
|
||||
RETURN_IF_NOT_OK(TensorRowToPlugin(input, &in_row));
|
||||
plugin::Status rc = plugin_op_->Compute(&in_row, &out_row);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
|
||||
RETURN_IF_NOT_OK(PluginToTensorRow(out_row, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PluginOp::PluginOp(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
|
||||
: lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {
|
||||
init_code_ = Init();
|
||||
}
|
||||
|
||||
Status PluginOp::Init() {
|
||||
plugin::PluginManagerBase *plugin;
|
||||
RETURN_IF_NOT_OK(PluginLoader::GetInstance()->LoadPlugin(lib_path_, &plugin));
|
||||
// casting a void pointer to specific type
|
||||
plugin_op_ = dynamic_cast<plugin::TensorOp *>(plugin->GetModule(func_name_));
|
||||
RETURN_UNEXPECTED_IF_NULL(plugin_op_);
|
||||
plugin::Status rc = plugin_op_->ParseSerializedArgs(user_args_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PLUGIN_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PLUGIN_OP_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/plugin/include/shared_include.h"
|
||||
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// a generalized plugin for TensorOp
|
||||
class PluginOp : public TensorOp {
|
||||
public:
|
||||
PluginOp(const std::string &lib_path, const std::string &func_name, const std::string &user_args);
|
||||
|
||||
~PluginOp() = default;
|
||||
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
Status Init(); // load plugin module
|
||||
|
||||
std::string Name() const override { return kPluginOp; }
|
||||
|
||||
// helper function to convert between plugin Tensor and MindData Tensor
|
||||
static Status PluginToTensorRow(const std::vector<plugin::Tensor> &, TensorRow *);
|
||||
|
||||
static Status TensorRowToPlugin(const TensorRow &, std::vector<plugin::Tensor> *);
|
||||
|
||||
private:
|
||||
Status init_code_;
|
||||
plugin::TensorOp *plugin_op_;
|
||||
std::string lib_path_;
|
||||
std::string func_name_;
|
||||
std::string user_args_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PLUGIN_OP_H_
|
|
@ -145,6 +145,7 @@ constexpr char kUniqueOp[] = "UniqueOp";
|
|||
// other
|
||||
constexpr char kCFuncOp[] = "CFuncOp";
|
||||
constexpr char kPyFuncOp[] = "PyFuncOp";
|
||||
constexpr char kPluginOp[] = "PluginOp";
|
||||
constexpr char kNoOp[] = "NoOp";
|
||||
|
||||
// A class that does a computation on a Tensor
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
file(GLOB _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
add_library(md_plugin OBJECT
|
||||
shared_lib_util.cc
|
||||
plugin_loader.cc
|
||||
)
|
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_INCLUDE_SHARED_INCLUDE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_INCLUDE_SHARED_INCLUDE_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
/***
|
||||
* This file is is complied with both MindData and plugin separately. Changing this file without compiling both
|
||||
* projects could lead to undefined behaviors.
|
||||
*/
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace plugin {
|
||||
// forward declares
|
||||
class PluginManagerBase;
|
||||
class MindDataManagerBase;
|
||||
// any plugin module is expected to expose these two functions as the entry point
|
||||
|
||||
/// \brief First handshake between plugin and MD.
|
||||
/// \param[in] MindDataManagerBase, a pointer for callback functions. (plugin can call MD function)
|
||||
/// \return status code, Status::OK() if function succeeds.
|
||||
extern "C" PluginManagerBase *GetInstance(MindDataManagerBase *);
|
||||
|
||||
/// \brief Definition of this function is expected to deallocate PluginManager
|
||||
/// \return void
|
||||
extern "C" void DestroyInstance();
|
||||
|
||||
/***
|
||||
* Tentative version rule for Plugin: X.Y.Z
|
||||
* X, major version, increment when additional file is included or other major changes
|
||||
* Y, minor version, increment when class/API are changed or other minor changes
|
||||
* Z, patch version, increment when bug fix is introduced or other patches
|
||||
*/
|
||||
static constexpr char kSharedIncludeVersion[] = "0.5.6";
|
||||
|
||||
/***
|
||||
* All derived classes defined in plugin side needs to inherit from this.
|
||||
*/
|
||||
class PluginBase {
|
||||
protected:
|
||||
virtual ~PluginBase() noexcept = default;
|
||||
};
|
||||
|
||||
/***
|
||||
* This class is used for callback. Functions defined in MindData can be exposed to plugin via this virtual class.
|
||||
* All derived classes of this have their definition on MindData side.
|
||||
*/
|
||||
class MindDataBase {
|
||||
protected:
|
||||
virtual ~MindDataBase() noexcept = default;
|
||||
};
|
||||
|
||||
/***
|
||||
* This is a simplified version of Status code. It intends to offer a simple <bool,string> return type. The syntax of
|
||||
* this class is modelled after existing Status code.
|
||||
*/
|
||||
class Status : PluginBase {
|
||||
public:
|
||||
static Status OK() noexcept { return Status(); }
|
||||
static Status ERROR(const std::string &msg) noexcept { return Status(msg); }
|
||||
Status(const Status &) = default;
|
||||
Status(Status &&) = default;
|
||||
|
||||
// helper functions
|
||||
bool IsOk() const noexcept { return success_; }
|
||||
const std::string &ToString() const noexcept { return status_msg_; }
|
||||
|
||||
private:
|
||||
Status() noexcept : success_(true) {}
|
||||
explicit Status(const std::string &msg) noexcept : success_(false), status_msg_(msg) {}
|
||||
const bool success_;
|
||||
const std::string status_msg_;
|
||||
};
|
||||
|
||||
/***
|
||||
* This is the interface through which MindData interacts with external .so files. There can only be 1 instance of
|
||||
* this class (hence the name Singleton) per so file. This class is the in-memory representation of each so file.
|
||||
* GetModule() returns class that contains runtime logic (e.g. GDALDecode). Raw pointer is used so that PluginManager
|
||||
* owns the lifetime of whatever objects it returns. MindData can not part-take in the memory management of plugin
|
||||
* objects. PluginManager is expected to be destroyed when DestroyInstance() is called.
|
||||
*/
|
||||
class PluginManagerBase : public PluginBase {
|
||||
public:
|
||||
virtual std::string GetPluginVersion() noexcept = 0;
|
||||
|
||||
virtual std::map<std::string, std::set<std::string>> GetModuleNames() noexcept = 0;
|
||||
|
||||
/// \brief return the module (e.g. a specific plugin tensor op) based on the module name. (names can be queried)
|
||||
/// \return pointer to the module. returns nullptr if module doesn't exist.
|
||||
virtual PluginBase *GetModule(const std::string &name) noexcept = 0;
|
||||
};
|
||||
|
||||
/***
|
||||
* This class is used to get functions (e.g. Log) from MindData.
|
||||
*/
|
||||
class MindDataManagerBase : public MindDataBase {
|
||||
public:
|
||||
virtual MindDataBase *GetModule(const std::string &name) noexcept = 0;
|
||||
};
|
||||
|
||||
/***
|
||||
* this struct is a Tensor in its simplest form, it is used to send Tensor data between MindData and Plugin.
|
||||
*/
|
||||
class Tensor : public PluginBase {
|
||||
public:
|
||||
std::vector<unsigned char> buffer_; // contains the actual content of tensor
|
||||
std::vector<int64_t> shape_; // shape of tensor, can be empty which means scalar
|
||||
std::vector<int64_t> offsets_; // store the offsets for only string Tensor
|
||||
std::string type_; // supported string literals "unknown", "bool", "int8", "uint8", "int16", "uint16", "int32",
|
||||
// "uint32", "int64", "uint64", "float16", "float32", "float64", "string"
|
||||
};
|
||||
|
||||
/***
|
||||
* This is plugin's TensorOp which resembles MindData's TensorOp. No exception is allowed. Each function needs to catch
|
||||
* all the errors thrown by 3rd party lib and if recovery isn't possible, return false and log the error. if MindData
|
||||
* sees an function returns false, it will quit immediately without any attempt to resolve the issue.
|
||||
*/
|
||||
class TensorOp : public PluginBase {
|
||||
public:
|
||||
/// \brief Parse input params for this op. This function will only be called once for the lifetime of this object.
|
||||
/// \return status code, Status::OK() if function succeeds.
|
||||
virtual Status ParseSerializedArgs(const std::string &) noexcept = 0;
|
||||
|
||||
/// \brief Perform operation on in_row and return out_row
|
||||
/// \param[in] in_row pointer to input tensor row
|
||||
/// \param[out] out_row pointer to output tensor row
|
||||
/// \return status code, Status::OK() if function succeeds.
|
||||
virtual Status Compute(std::vector<Tensor> *in_row, std::vector<Tensor> *out_row) noexcept = 0;
|
||||
};
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_INCLUDE_SHARED_INCLUDE_H_
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/plugin/plugin_loader.h"
|
||||
#include "minddata/dataset/plugin/shared_lib_util.h"
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PluginLoader *PluginLoader::GetInstance() noexcept {
|
||||
static PluginLoader pl;
|
||||
return &pl;
|
||||
}
|
||||
|
||||
PluginLoader::~PluginLoader() {
|
||||
std::vector<std::string> keys;
|
||||
// get the keys from map, this is to avoid concurrent iteration and delete
|
||||
std::transform(plugins_.begin(), plugins_.end(), std::back_inserter(keys), [](const auto &p) { return p.first; });
|
||||
for (std::string &key : keys) {
|
||||
Status rc = UnloadPlugin(key);
|
||||
MSLOG_IF(ERROR, rc.IsError(), mindspore::NoExceptionType) << rc.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
// LoadPlugin() is NOT thread-safe. It is supposed to be called when Ops are being built. E.g. PluginOp should call this
|
||||
// within constructor instead of in its Compute() which is parallel.
|
||||
Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin) {
|
||||
RETURN_UNEXPECTED_IF_NULL(singleton_plugin);
|
||||
auto itr = plugins_.find(filename);
|
||||
// return ok if this module is already loaded
|
||||
if (itr != plugins_.end()) {
|
||||
*singleton_plugin = itr->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
// Open the .so file
|
||||
void *handle = SharedLibUtil::Load(filename);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr, "fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
// Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
|
||||
void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
// cast the returned function ptr of type void* to the type of GetInstance
|
||||
plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
|
||||
reinterpret_cast<plugin::PluginManagerBase *(*)(plugin::MindDataManagerBase *)>(func_handle);
|
||||
RETURN_UNEXPECTED_IF_NULL(get_instance);
|
||||
|
||||
*singleton_plugin = get_instance(nullptr); // call function ptr to get instance
|
||||
RETURN_UNEXPECTED_IF_NULL(*singleton_plugin);
|
||||
|
||||
std::string v1 = (*singleton_plugin)->GetPluginVersion(), v2(plugin::kSharedIncludeVersion);
|
||||
|
||||
// Version check, if version are not the same, log the error and return fail
|
||||
if (v1 != v2) {
|
||||
std::string err_msg = "[Plugin Version Error] expected:" + v2 + ", received:" + v1 + " please recompile.";
|
||||
if (SharedLibUtil::Close(handle) != 0) err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
const std::map<std::string, std::set<std::string>> module_names = (*singleton_plugin)->GetModuleNames();
|
||||
for (auto &p : module_names) {
|
||||
std::string msg = "Plugin " + p.first + " has module:";
|
||||
MS_LOG(DEBUG) << std::accumulate(p.second.begin(), p.second.end(), msg,
|
||||
[](const std::string &msg, const std::string &nm) { return msg + " " + nm; });
|
||||
}
|
||||
|
||||
// save the name and handle
|
||||
plugins_.insert({filename, {*singleton_plugin, handle}});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PluginLoader::UnloadPlugin(const std::string &filename) {
|
||||
auto itr = plugins_.find(filename);
|
||||
RETURN_OK_IF_TRUE(itr == plugins_.end()); // return true if this plugin was never loaded or already removed
|
||||
|
||||
void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
|
||||
RETURN_UNEXPECTED_IF_NULL(destroy_instance);
|
||||
|
||||
destroy_instance();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
|
||||
"dlclose() error: " + SharedLibUtil::ErrMsg());
|
||||
|
||||
plugins_.erase(filename);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_PLUGIN_LOADER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_PLUGIN_LOADER_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/plugin/include/shared_include.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// This class manages all MindData's plugins. It serves as the singleton that owns all plugins and bridge the gap
|
||||
// between C++ RAII and C style functions
|
||||
class PluginLoader {
|
||||
public:
|
||||
/// \brief Singleton getter,
|
||||
/// \return pointer to PluginLoader
|
||||
static PluginLoader *GetInstance() noexcept;
|
||||
|
||||
PluginLoader() = default;
|
||||
|
||||
/// \brief destructor, will call unload internally to unload all plugins managed by PluginLoader
|
||||
~PluginLoader();
|
||||
|
||||
/// \brief load an shared object (.so file) via dlopen() and return the ptr to the loaded file (singleton_plugin).
|
||||
/// \param[in] filename the full path to .so file
|
||||
/// \param[out] singleton_plugin pointer to the loaded file
|
||||
/// \return status code
|
||||
Status LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin);
|
||||
|
||||
private:
|
||||
/// \brief Unload so file, internally will call dlclose() and delete its handle.
|
||||
/// \param[in] filename, the full path to .so file
|
||||
/// \return status code
|
||||
Status UnloadPlugin(const std::string &filename);
|
||||
|
||||
std::map<std::string, std::pair<plugin::PluginManagerBase *, void *>>
|
||||
plugins_; // key: path, val: plugin, dlopen handle
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_PLUGIN_LOADER_H_
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/plugin/shared_lib_util.h"
|
||||
#ifdef __linux__
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
#ifdef __linux__
|
||||
void *SharedLibUtil::Load(const std::string &name) { return dlopen(name.c_str(), RTLD_LAZY); }
|
||||
void *SharedLibUtil::FindSym(void *handle, const std::string &name) { return dlsym(handle, name.c_str()); }
|
||||
int32_t SharedLibUtil::Close(void *handle) { return dlclose(handle); }
|
||||
std::string SharedLibUtil::ErrMsg() { return std::string(dlerror()); }
|
||||
#else // MindData currently doesn't support loading shared library on platform that doesn't support dlopen
|
||||
void *SharedLibUtil::Load(const std::string &name) { return nullptr; }
|
||||
void *SharedLibUtil::FindSym(void *handle, const std::string &name) { return nullptr; }
|
||||
int32_t SharedLibUtil::Close(void *handle) { return -1; }
|
||||
std::string SharedLibUtil::ErrMsg() { return std::string("Plugin on non-Linux platform is not yet supported."); }
|
||||
#endif
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_SHARED_LIB_UTIL_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_SHARED_LIB_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// This class is a collection of util functions which aims at abstracting the dependency on OS
|
||||
class SharedLibUtil {
|
||||
public:
|
||||
static void *Load(const std::string &name);
|
||||
|
||||
static void *FindSym(void *handle, const std::string &name);
|
||||
|
||||
static int32_t Close(void *handle);
|
||||
|
||||
static std::string ErrMsg();
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PLUGIN_SHARED_LIB_UTIL_H_
|
|
@ -64,7 +64,7 @@ void Task::operator()() {
|
|||
if (rc_.StatusCode() == StatusCode::kMDNetWorkError) {
|
||||
MS_LOG(WARNING) << rc_;
|
||||
} else {
|
||||
MS_LOG(ERROR) << rc_;
|
||||
MS_LOG(ERROR) << "Task: " << my_name_ << " is terminated with err msg: " << rc_;
|
||||
}
|
||||
ShutdownGroup();
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .validators import check_num_classes, check_ms_type, check_fill_value, check_slice_option, check_slice_op, \
|
||||
check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops
|
||||
check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops, check_plugin
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
|
||||
|
||||
|
@ -30,6 +30,7 @@ class TensorOperation:
|
|||
"""
|
||||
Base class Tensor Ops
|
||||
"""
|
||||
|
||||
def __call__(self, *input_tensor_list):
|
||||
tensor_row = []
|
||||
for tensor in input_tensor_list:
|
||||
|
@ -37,7 +38,7 @@ class TensorOperation:
|
|||
tensor_row.append(cde.Tensor(np.asarray(tensor)))
|
||||
except RuntimeError:
|
||||
raise TypeError("Invalid user input. Got {}: {}, cannot be converted into tensor." \
|
||||
.format(type(tensor), tensor))
|
||||
.format(type(tensor), tensor))
|
||||
callable_op = cde.Execute(self.parse())
|
||||
output_tensor_list = callable_op(tensor_row)
|
||||
for i, element in enumerate(output_tensor_list):
|
||||
|
@ -392,6 +393,7 @@ class Unique(TensorOperation):
|
|||
>>> # +---------+-----------------+---------+
|
||||
|
||||
"""
|
||||
|
||||
def parse(self):
|
||||
return cde.UniqueOperation()
|
||||
|
||||
|
@ -474,3 +476,27 @@ class RandomChoice(TensorOperation):
|
|||
else:
|
||||
operations.append(op)
|
||||
return cde.RandomChoiceOperation(operations)
|
||||
|
||||
|
||||
class Plugin(TensorOperation):
|
||||
"""
|
||||
Plugin support for MindData. Use this class to dynamically load a .so file (shared library) and execute its symbols.
|
||||
|
||||
Args:
|
||||
lib_path (str): Path to .so file which is compiled to support MindData plugin.
|
||||
func_name (str): Name of the function to load from the .so file.
|
||||
user_args (str, optional): Serialized args to pass to the plugin. Only needed if "func_name" requires one.
|
||||
|
||||
Examples:
|
||||
>>> plugin = c_transforms.Plugin("pluginlib.so", "PluginDecode")
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=plugin)
|
||||
"""
|
||||
|
||||
@check_plugin
|
||||
def __init__(self, lib_path, func_name, user_args=None):
|
||||
self.lib_path = lib_path
|
||||
self.func_name = func_name
|
||||
self.user_args = str() if (user_args is None) else user_args
|
||||
|
||||
def parse(self):
|
||||
return cde.PluginOperation(self.lib_path, self.func_name, self.user_args)
|
||||
|
|
|
@ -272,7 +272,7 @@ def check_random_apply(method):
|
|||
for i, transform in enumerate(transforms):
|
||||
if str(transform).find("c_transform") >= 0:
|
||||
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
|
||||
.format(i))
|
||||
.format(i))
|
||||
|
||||
if prob is not None:
|
||||
type_check(prob, (float, int,), "prob")
|
||||
|
@ -294,7 +294,24 @@ def check_transforms_list(method):
|
|||
for i, transform in enumerate(transforms):
|
||||
if str(transform).find("c_transform") >= 0:
|
||||
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
|
||||
.format(i))
|
||||
.format(i))
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_plugin(method):
|
||||
"""Wrapper method to check the parameters of plugin."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[lib_path, func_name, user_args], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(lib_path, (str,), "lib_path")
|
||||
type_check(func_name, (str,), "func_name")
|
||||
if user_args is not None:
|
||||
type_check(user_args, (str,), "user_args")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -27,16 +27,15 @@
|
|||
namespace common = mindspore::common;
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestTensorShape : public UT::Common {
|
||||
public:
|
||||
MindDataTestTensorShape() = default;
|
||||
MindDataTestTensorShape() = default;
|
||||
};
|
||||
|
||||
|
||||
TEST_F(MindDataTestTensorShape, TestBasics) {
|
||||
std::vector<dsize_t> vec = {4, 5, 6};
|
||||
TensorShape t(vec);
|
||||
|
@ -111,7 +110,7 @@ TEST_F(MindDataTestTensorShape, TestUnknown) {
|
|||
|
||||
// Test materializing a TensorShape by calling method on a given column descriptor
|
||||
TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
||||
int32_t rank = 0; // not used
|
||||
int32_t rank = 0; // not used
|
||||
int32_t num_elements = 0;
|
||||
|
||||
// Has no shape
|
||||
|
@ -121,7 +120,7 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
Status rc = c1.MaterializeTensorShape(num_elements, &generated_shape1);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
MS_LOG(INFO) << "generated_shape1: " << common::SafeCStr(generated_shape1.ToString()) << ".";
|
||||
ASSERT_EQ(TensorShape({4}),generated_shape1);
|
||||
ASSERT_EQ(TensorShape({4}), generated_shape1);
|
||||
|
||||
// Has shape <DIM_UNKNOWN> i.e. <*>
|
||||
TensorShape requested_shape2({TensorShape::kDimUnknown});
|
||||
|
@ -131,7 +130,7 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
rc = c2.MaterializeTensorShape(num_elements, &generated_shape2);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
MS_LOG(INFO) << "generated_shape2: " << common::SafeCStr(generated_shape2.ToString()) << ".";
|
||||
ASSERT_EQ(TensorShape({5}),generated_shape2);
|
||||
ASSERT_EQ(TensorShape({5}), generated_shape2);
|
||||
|
||||
// Compute unknown dimension <*,4>
|
||||
TensorShape requested_shape3({TensorShape::kDimUnknown, 4});
|
||||
|
@ -141,7 +140,7 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
rc = c3.MaterializeTensorShape(num_elements, &generated_shape3);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
MS_LOG(INFO) << "generated_shape3: " << common::SafeCStr(generated_shape3.ToString()) << ".";
|
||||
ASSERT_EQ(TensorShape({3,4}),generated_shape3);
|
||||
ASSERT_EQ(TensorShape({3, 4}), generated_shape3);
|
||||
|
||||
// Compute unknown dimension <3,*,4>
|
||||
TensorShape requested_shape4({3, TensorShape::kDimUnknown, 4});
|
||||
|
@ -151,7 +150,7 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
rc = c4.MaterializeTensorShape(num_elements, &generated_shape4);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
MS_LOG(INFO) << "generated_shape4: " << common::SafeCStr(generated_shape4.ToString()) << ".";
|
||||
ASSERT_EQ(TensorShape({3,2,4}),generated_shape4);
|
||||
ASSERT_EQ(TensorShape({3, 2, 4}), generated_shape4);
|
||||
|
||||
// requested and generated should be the same! <2,3,4>
|
||||
TensorShape requested_shape5({2, 3, 4});
|
||||
|
@ -161,7 +160,7 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
rc = c5.MaterializeTensorShape(num_elements, &generated_shape5);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
MS_LOG(INFO) << "generated_shape5: " << common::SafeCStr(generated_shape5.ToString()) << ".";
|
||||
ASSERT_EQ(requested_shape5,generated_shape5);
|
||||
ASSERT_EQ(requested_shape5, generated_shape5);
|
||||
|
||||
// expect fail due to multiple unknown dimensions
|
||||
TensorShape requested_shape6({2, TensorShape::kDimUnknown, TensorShape::kDimUnknown});
|
||||
|
@ -181,6 +180,5 @@ TEST_F(MindDataTestTensorShape, TestColDescriptor) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestTensorShape, TestInvalid) {
|
||||
ASSERT_EQ(TensorShape({2147483648}), TensorShape::CreateUnknownRankShape());
|
||||
ASSERT_EQ(TensorShape({kDeMaxDim - 1, kDeMaxDim - 1, kDeMaxDim - 1}), TensorShape::CreateUnknownRankShape());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue