forked from mindspore-Ecosystem/mindspore
!48489 [MS]c_api support bool attr, tensor from file, save ir path
Merge pull request !48489 from liyejun/c_api_hl
This commit is contained in:
commit
3ded5cf3e3
|
@ -145,7 +145,7 @@ MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, c
|
|||
/// \param[in] value_num Size of the given array.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSOpGetAttrsInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
size_t value_num);
|
||||
|
||||
/// \brief Create new Int64 attribute scalar value.
|
||||
|
@ -164,6 +164,14 @@ MIND_C_API AttrHandle MSNewAttrInt64(ResMgrHandle res_mgr, const int64_t v);
|
|||
/// \return Attribute value handle.
|
||||
MIND_C_API AttrHandle MSNewAttrFloat32(ResMgrHandle res_mgr, const float v);
|
||||
|
||||
/// \brief Create new Bool attribute scalar value.
|
||||
///
|
||||
/// \param[in] res_mgr Resource Handle that manages the nodes of the funcGraph.
|
||||
/// \param[in] v Given value.
|
||||
///
|
||||
/// \return Attribute value handle.
|
||||
MIND_C_API AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v);
|
||||
|
||||
/// \brief Create new attribute value with array.
|
||||
///
|
||||
/// \param[in] res_mgr Resource Handle that manages the nodes of the funcGraph.
|
||||
|
|
|
@ -72,6 +72,8 @@ MIND_C_API void MSSetDeviceId(uint32_t deviceId);
|
|||
/// 3: Full level. Save all IR graphs.
|
||||
MIND_C_API void MSSetGraphsSaveMode(int save_mode);
|
||||
|
||||
MIND_C_API void MSSetGraphsSavePath(const char *save_path);
|
||||
|
||||
/// \brief Set flag for auto shape and type infer
|
||||
///
|
||||
/// \param res_mgr Resource Handle that manages the nodes of the funcGraph.
|
||||
|
|
|
@ -42,6 +42,18 @@ extern "C" {
|
|||
MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[],
|
||||
size_t shape_size, size_t data_len);
|
||||
|
||||
/// \brief Create a tensor with path to a space-sperated txt file.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
/// \param[in] type [TypeId] Data type of the tensor.
|
||||
/// \param[in] shape The shape arary of the tensor.
|
||||
/// \param[in] shape_size The size of shape array, i.e., the rank of the tensor.
|
||||
/// \param[in] path path to the file.
|
||||
///
|
||||
/// \return The pointer of the created tensor instance.
|
||||
MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
const char *path);
|
||||
|
||||
/// \brief Create a tensor with input data buffer and given source data type.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
|
|
|
@ -240,7 +240,7 @@ int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *
|
|||
}
|
||||
}
|
||||
|
||||
STATUS MSOpGetAttrsInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
size_t value_num) {
|
||||
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
|
||||
|
@ -288,6 +288,15 @@ AttrHandle MSNewAttrFloat32(ResMgrHandle res_mgr, const float v) {
|
|||
return GetRawPtr(res_mgr, value);
|
||||
}
|
||||
|
||||
AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v) {
|
||||
if (res_mgr == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto value = std::make_shared<BoolImmImpl>(v);
|
||||
return GetRawPtr(res_mgr, value);
|
||||
}
|
||||
|
||||
AttrHandle MSOpNewAttrs(ResMgrHandle res_mgr, void *value, size_t vec_size, TypeId data_type) {
|
||||
if (res_mgr == nullptr || value == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [value_vec] is nullptr.";
|
||||
|
|
|
@ -85,6 +85,13 @@ void MSSetGraphsSaveMode(int save_mode) {
|
|||
return;
|
||||
}
|
||||
|
||||
void MSSetGraphsSavePath(const char *save_path) {
|
||||
MS_LOG(DEBUG) << "Set Graphs Save Path: " << save_path;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<std::string>(mindspore::MS_CTX_SAVE_GRAPHS_PATH, save_path);
|
||||
return;
|
||||
}
|
||||
|
||||
void MSSetInfer(ResMgrHandle res_mgr, bool infer) {
|
||||
MS_LOG(DEBUG) << "Set Infer Graph: " << infer;
|
||||
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
|
||||
|
|
|
@ -21,6 +21,24 @@
|
|||
#include "ir/tensor.h"
|
||||
#include "ir/dtype.h"
|
||||
|
||||
template <typename T>
|
||||
void GetDataByFile(std::vector<T> data, const char *path, size_t *elem_size) {
|
||||
std::string fileName(path);
|
||||
MS_LOG(INFO) << "Reading File: " << fileName << std::endl;
|
||||
std::ifstream fin(fileName, std::ios::in);
|
||||
if (!fin.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file failed, File path: %s " << fileName << std::endl;
|
||||
return;
|
||||
}
|
||||
T t;
|
||||
while (fin >> t) {
|
||||
data.push_back(t);
|
||||
}
|
||||
fin.close();
|
||||
*elem_size = data.size() * sizeof(T);
|
||||
return;
|
||||
}
|
||||
|
||||
TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
size_t data_len) {
|
||||
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
|
||||
|
@ -38,6 +56,46 @@ TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const in
|
|||
return GetRawPtr(res_mgr, tensor);
|
||||
}
|
||||
|
||||
TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
const char *path) {
|
||||
if (res_mgr == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [shape] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
TensorPtr tensor = nullptr;
|
||||
ShapeVector shape_vec(shape, shape + shape_size);
|
||||
try {
|
||||
size_t data_len;
|
||||
switch (type) {
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
std::vector<int32_t> data;
|
||||
(void)GetDataByFile<int32_t>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
std::vector<int64_t> data;
|
||||
(void)GetDataByFile<int64_t>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
std::vector<float> data;
|
||||
(void)GetDataByFile<float>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ TypeId: " << type << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "New Tensor failed. Error info: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
return GetRawPtr(res_mgr, tensor);
|
||||
}
|
||||
|
||||
TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int64_t shape[], size_t shape_size,
|
||||
TypeId tensor_type, TypeId src_type) {
|
||||
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
|
||||
|
|
|
@ -64,7 +64,7 @@ TEST_F(TestCApiAttr, test_attr) {
|
|||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(attr1_retrived, 1);
|
||||
int64_t values[2];
|
||||
ret = MSOpGetAttrsInt64(res_mgr,op_add,"attr2",values,2);
|
||||
ret = MSOpGetAttrArrayInt64(res_mgr, op_add, "attr2", values, 2);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(values[0], 2);
|
||||
ASSERT_EQ(values[1], 2);
|
||||
|
@ -78,8 +78,10 @@ TEST_F(TestCApiAttr, test_attr) {
|
|||
values[1] = 1;
|
||||
ret = MSOpSetAttrArray(res_mgr, op_add, "attr2", values, 2, kNumberTypeInt64);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ret = MSOpGetAttrArrayInt64(res_mgr, op_add, "attr2", values, 2);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(values[0], 1);
|
||||
ASSERT_EQ(values[1], 1);
|
||||
MSResourceManagerDestroy(res_mgr);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue