forked from mindspore-Ecosystem/mindspore
mindspore server inference
This commit is contained in:
parent
a58b1a1435
commit
dfbb232468
|
@ -0,0 +1,42 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_INCLUDE_MS_SESSION_H
|
||||||
|
#define MINDSPORE_INCLUDE_MS_SESSION_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include "include/ms_tensor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class FuncGraph;
|
||||||
|
namespace inference {
|
||||||
|
class MS_API MSSession {
|
||||||
|
public:
|
||||||
|
MSSession() = default;
|
||||||
|
|
||||||
|
static std::shared_ptr<MSSession> CreateSession(const std::string &device, uint32_t device_id);
|
||||||
|
|
||||||
|
virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
|
||||||
|
|
||||||
|
virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
|
@ -0,0 +1,69 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||||
|
#define MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "ir/dtype/type_id.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
namespace inference {
|
||||||
|
class MS_API MSTensor {
|
||||||
|
public:
|
||||||
|
MSTensor() = default;
|
||||||
|
// brief Create a MSTensor pointer.
|
||||||
|
//
|
||||||
|
// param data_type DataTypeId of tensor to be created.
|
||||||
|
// param shape Shape of tensor to be created.
|
||||||
|
// return MSTensor pointer.
|
||||||
|
static MSTensor *CreateTensor(TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
|
~MSTensor() = default;
|
||||||
|
|
||||||
|
virtual TypeId data_type() const = 0;
|
||||||
|
|
||||||
|
virtual TypeId set_data_type(const TypeId data_type) = 0;
|
||||||
|
|
||||||
|
virtual std::vector<int> shape() const = 0;
|
||||||
|
|
||||||
|
virtual size_t set_shape(const std::vector<int> &shape) = 0;
|
||||||
|
|
||||||
|
virtual int DimensionSize(size_t index) const = 0;
|
||||||
|
// brief Get number of element in MSTensor.
|
||||||
|
//
|
||||||
|
// return Number of element in MSTensor.
|
||||||
|
virtual int ElementsNum() const = 0;
|
||||||
|
|
||||||
|
virtual std::size_t hash() const = 0;
|
||||||
|
// brief Get byte size of data in MSTensor.
|
||||||
|
//
|
||||||
|
// return Byte size of data in MSTensor.
|
||||||
|
virtual size_t Size() const = 0;
|
||||||
|
// brief Get pointer of data in MSTensor.
|
||||||
|
//
|
||||||
|
// The data pointer can be used to both write or read data in MSTensor.
|
||||||
|
//
|
||||||
|
// return A pointer points to data in MSTensor.
|
||||||
|
virtual void *MutableData() const = 0;
|
||||||
|
};
|
||||||
|
using MultiTensor = std::vector<std::vector<std::shared_ptr<inference::MSTensor>>>;
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
|
|
@ -237,3 +237,29 @@ if (ENABLE_MINDDATA)
|
||||||
add_subdirectory(mindrecord)
|
add_subdirectory(mindrecord)
|
||||||
add_subdirectory(dataset)
|
add_subdirectory(dataset)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
|
# build inference
|
||||||
|
set(LOAD_ONNX_SRC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
|
||||||
|
)
|
||||||
|
add_library(inference SHARED
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
||||||
|
${LOAD_ONNX_SRC}
|
||||||
|
)
|
||||||
|
target_link_libraries(inference PRIVATE ${PYTHON_LIB} ${SECUREC_LIBRARY}
|
||||||
|
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||||
|
|
||||||
|
if (ENABLE_CPU)
|
||||||
|
target_link_libraries(inference PRIVATE mindspore::dnnl mindspore::mkldnn)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (USE_GLOG)
|
||||||
|
target_link_libraries(inference PRIVATE mindspore::glog)
|
||||||
|
else()
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
|
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
|
||||||
|
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
|
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
||||||
|
endif ()
|
||||||
|
endif()
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
file(GLOB_RECURSE _IR_SRC_LIST ./*.cc dtype/*.cc)
|
file(GLOB_RECURSE _IR_SRC_LIST ./*.cc dtype/*.cc)
|
||||||
|
file(GLOB_RECURSE _IR_LITE_SRC_FILES
|
||||||
|
./lite/tensor.cc
|
||||||
|
)
|
||||||
|
list(REMOVE_ITEM _IR_SRC_LIST ${_IR_LITE_SRC_FILES})
|
||||||
set_property(SOURCE ${_IR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_IR)
|
set_property(SOURCE ${_IR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_IR)
|
||||||
add_library(_mindspore_ir_obj OBJECT ${_IR_SRC_LIST})
|
add_library(_mindspore_ir_obj OBJECT ${_IR_SRC_LIST})
|
||||||
|
|
|
@ -34,65 +34,9 @@
|
||||||
|
|
||||||
#include "ir/base.h"
|
#include "ir/base.h"
|
||||||
#include "ir/named.h"
|
#include "ir/named.h"
|
||||||
|
#include "ir/dtype/type_id.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
//
|
|
||||||
// Supported meta type
|
|
||||||
//
|
|
||||||
enum TypeId : int {
|
|
||||||
kTypeUnknown = 0,
|
|
||||||
kMetaTypeBegin = kTypeUnknown,
|
|
||||||
kMetaTypeType, // Type
|
|
||||||
kMetaTypeAnything,
|
|
||||||
kMetaTypeObject,
|
|
||||||
kMetaTypeTypeType, // TypeType
|
|
||||||
kMetaTypeProblem,
|
|
||||||
kMetaTypeExternal,
|
|
||||||
kMetaTypeNone,
|
|
||||||
kMetaTypeNull,
|
|
||||||
kMetaTypeEllipsis,
|
|
||||||
kMetaTypeEnd,
|
|
||||||
//
|
|
||||||
// Object types
|
|
||||||
//
|
|
||||||
kObjectTypeBegin = kMetaTypeEnd,
|
|
||||||
kObjectTypeNumber,
|
|
||||||
kObjectTypeString,
|
|
||||||
kObjectTypeList,
|
|
||||||
kObjectTypeTuple,
|
|
||||||
kObjectTypeSlice,
|
|
||||||
kObjectTypeKeyword,
|
|
||||||
kObjectTypeTensorType,
|
|
||||||
kObjectTypeClass,
|
|
||||||
kObjectTypeDictionary,
|
|
||||||
kObjectTypeFunction,
|
|
||||||
kObjectTypeJTagged,
|
|
||||||
kObjectTypeSymbolicKeyType,
|
|
||||||
kObjectTypeEnvType,
|
|
||||||
kObjectTypeRefKey,
|
|
||||||
kObjectTypeRef,
|
|
||||||
kObjectTypeEnd,
|
|
||||||
//
|
|
||||||
// Number Types
|
|
||||||
//
|
|
||||||
kNumberTypeBegin = kObjectTypeEnd,
|
|
||||||
kNumberTypeBool,
|
|
||||||
kNumberTypeInt,
|
|
||||||
kNumberTypeInt8,
|
|
||||||
kNumberTypeInt16,
|
|
||||||
kNumberTypeInt32,
|
|
||||||
kNumberTypeInt64,
|
|
||||||
kNumberTypeUInt,
|
|
||||||
kNumberTypeUInt8,
|
|
||||||
kNumberTypeUInt16,
|
|
||||||
kNumberTypeUInt32,
|
|
||||||
kNumberTypeUInt64,
|
|
||||||
kNumberTypeFloat,
|
|
||||||
kNumberTypeFloat16,
|
|
||||||
kNumberTypeFloat32,
|
|
||||||
kNumberTypeFloat64,
|
|
||||||
kNumberTypeEnd
|
|
||||||
};
|
|
||||||
|
|
||||||
TypeId IntBitsToTypeId(const int nbits);
|
TypeId IntBitsToTypeId(const int nbits);
|
||||||
TypeId UIntBitsToTypeId(const int nbits);
|
TypeId UIntBitsToTypeId(const int nbits);
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
/**
|
||||||
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
|
*
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_
|
||||||
|
#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
//
|
||||||
|
// Supported meta type
|
||||||
|
//
|
||||||
|
enum TypeId : int {
|
||||||
|
kTypeUnknown = 0,
|
||||||
|
kMetaTypeBegin = kTypeUnknown,
|
||||||
|
kMetaTypeType, // Type
|
||||||
|
kMetaTypeAnything,
|
||||||
|
kMetaTypeObject,
|
||||||
|
kMetaTypeTypeType, // TypeType
|
||||||
|
kMetaTypeProblem,
|
||||||
|
kMetaTypeExternal,
|
||||||
|
kMetaTypeNone,
|
||||||
|
kMetaTypeNull,
|
||||||
|
kMetaTypeEllipsis,
|
||||||
|
kMetaTypeEnd,
|
||||||
|
//
|
||||||
|
// Object types
|
||||||
|
//
|
||||||
|
kObjectTypeBegin = kMetaTypeEnd,
|
||||||
|
kObjectTypeNumber,
|
||||||
|
kObjectTypeString,
|
||||||
|
kObjectTypeList,
|
||||||
|
kObjectTypeTuple,
|
||||||
|
kObjectTypeSlice,
|
||||||
|
kObjectTypeKeyword,
|
||||||
|
kObjectTypeTensorType,
|
||||||
|
kObjectTypeClass,
|
||||||
|
kObjectTypeDictionary,
|
||||||
|
kObjectTypeFunction,
|
||||||
|
kObjectTypeJTagged,
|
||||||
|
kObjectTypeSymbolicKeyType,
|
||||||
|
kObjectTypeEnvType,
|
||||||
|
kObjectTypeRefKey,
|
||||||
|
kObjectTypeRef,
|
||||||
|
kObjectTypeEnd,
|
||||||
|
//
|
||||||
|
// Number Types
|
||||||
|
//
|
||||||
|
kNumberTypeBegin = kObjectTypeEnd,
|
||||||
|
kNumberTypeBool,
|
||||||
|
kNumberTypeInt,
|
||||||
|
kNumberTypeInt8,
|
||||||
|
kNumberTypeInt16,
|
||||||
|
kNumberTypeInt32,
|
||||||
|
kNumberTypeInt64,
|
||||||
|
kNumberTypeUInt,
|
||||||
|
kNumberTypeUInt8,
|
||||||
|
kNumberTypeUInt16,
|
||||||
|
kNumberTypeUInt32,
|
||||||
|
kNumberTypeUInt64,
|
||||||
|
kNumberTypeFloat,
|
||||||
|
kNumberTypeFloat16,
|
||||||
|
kNumberTypeFloat32,
|
||||||
|
kNumberTypeFloat64,
|
||||||
|
kNumberTypeEnd
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_
|
|
@ -14,18 +14,18 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_MINNIE_H_
|
#ifndef MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_
|
||||||
#define MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_MINNIE_H_
|
#define MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ParamValueMinnie : public ParamValue {
|
class ParamValueLite : public ParamValue {
|
||||||
public:
|
public:
|
||||||
ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {}
|
ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
|
||||||
virtual ~ParamValueMinnie() = default;
|
virtual ~ParamValueLite() = default;
|
||||||
|
|
||||||
size_t tensor_size() const { return tensor_size_; }
|
size_t tensor_size() const { return tensor_size_; }
|
||||||
void set_tensor_size(size_t size) { tensor_size_ = size; }
|
void set_tensor_size(size_t size) { tensor_size_ = size; }
|
||||||
|
@ -38,6 +38,6 @@ class ParamValueMinnie : public ParamValue {
|
||||||
size_t tensor_size_;
|
size_t tensor_size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>;
|
using ParamValueLitePtr = std::shared_ptr<ParamValueLite>;
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_MINNIE_H_
|
#endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_
|
|
@ -0,0 +1,152 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include "ir/lite/tensor.h"
|
||||||
|
#include "securec/include/securec.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace tensor {
|
||||||
|
#define kMaxMallocSize 1024 * 1024 * 100
|
||||||
|
Tensor::Tensor(const TypeId data_type, const std::vector<int> &shape) : MetaTensor(data_type, shape) {}
|
||||||
|
|
||||||
|
Tensor::Tensor(const TypePtr &type_ptr, const std::vector<int> &shape) : MetaTensor(type_ptr, shape) {}
|
||||||
|
|
||||||
|
Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) {
|
||||||
|
this->data_type_ = tensor.data_type_;
|
||||||
|
this->shape_ = tensor.shape_;
|
||||||
|
auto ret = CopyTensorData(tensor);
|
||||||
|
if (0 != ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "CopyTensorData error";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int Tensor::CopyTensorData(const Tensor &srcTensor) {
|
||||||
|
if (srcTensor.data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "data of srcTensor is nullptr";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
size_t data_size = this->Size();
|
||||||
|
MS_ASSERT(data_size == tensor.Size());
|
||||||
|
if (this->data_ == nullptr) {
|
||||||
|
if (data_size > kMaxMallocSize) {
|
||||||
|
MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
this->data_ = malloc(data_size);
|
||||||
|
}
|
||||||
|
memcpy_s(this->data_, data_size, tensor.data_, tensor.Size());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::~Tensor() {
|
||||||
|
if (nullptr != this->data_) {
|
||||||
|
free(this->data_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor &Tensor::operator=(const Tensor &tensor) {
|
||||||
|
if (&tensor == this) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
this->shape_ = tensor.shape_;
|
||||||
|
this->data_type_ = tensor.data_type_;
|
||||||
|
auto ret = CopyTensorData(tensor);
|
||||||
|
if (0 != ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "CopyTensorData error";
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Tensor::operator==(const Tensor &tensor) {
|
||||||
|
return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Tensor::operator==(const Value &other) const {
|
||||||
|
if (other.isa<Tensor>()) {
|
||||||
|
auto other_ = static_cast<const Tensor &>(other);
|
||||||
|
return *this == other_;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace tensor
|
||||||
|
|
||||||
|
namespace inference {
|
||||||
|
MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
|
||||||
|
return new Tensor(data_type, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor() { this->tensor_impl_ = std::make_shared<tensor::Tensor>(); }
|
||||||
|
|
||||||
|
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
|
||||||
|
this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
|
||||||
|
|
||||||
|
TypeId Tensor::data_type() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->data_type();
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeId Tensor::set_data_type(TypeId data_type) {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->set_data_type(data_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> Tensor::shape() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Tensor::set_shape(const std::vector<int> &shape) {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->set_shape(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Tensor::DimensionSize(size_t index) const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->DimensionSize(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Tensor::ElementsNum() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->ElementsNum();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t Tensor::hash() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->hash();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Tensor::Size() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->Size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void *Tensor::MutableData() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->data();
|
||||||
|
}
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,97 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_IR_LITE_TENSOR_H_
|
||||||
|
#define MINDSPORE_CCSRC_IR_LITE_TENSOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/meta_tensor.h"
|
||||||
|
#include "ir/dtype/type.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace tensor {
|
||||||
|
class Tensor : public MetaTensor {
|
||||||
|
public:
|
||||||
|
Tensor() : MetaTensor() {}
|
||||||
|
|
||||||
|
Tensor(const TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
|
Tensor(const TypePtr &type_ptr, const std::vector<int> &shape);
|
||||||
|
|
||||||
|
Tensor(const Tensor &tensor);
|
||||||
|
|
||||||
|
~Tensor();
|
||||||
|
|
||||||
|
int CopyTensorData(const Tensor &srcTensor);
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(Tensor, MetaTensor)
|
||||||
|
|
||||||
|
virtual Tensor &operator=(const Tensor &tensor);
|
||||||
|
|
||||||
|
virtual bool operator==(const Tensor &tensor);
|
||||||
|
|
||||||
|
bool operator==(const Value &other) const override;
|
||||||
|
|
||||||
|
size_t Size() const { return MetaTensor::ElementsNum() * GetTypeByte(TypeIdToType(this->data_type_)); }
|
||||||
|
|
||||||
|
void *Data() const { return data_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void *data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using TensorPtr = std::shared_ptr<Tensor>;
|
||||||
|
} // namespace tensor
|
||||||
|
|
||||||
|
namespace inference {
|
||||||
|
class Tensor : public MSTensor {
|
||||||
|
public:
|
||||||
|
Tensor();
|
||||||
|
|
||||||
|
Tensor(TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
|
explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
|
||||||
|
|
||||||
|
~Tensor() = default;
|
||||||
|
|
||||||
|
TypeId data_type() const override;
|
||||||
|
|
||||||
|
TypeId set_data_type(const TypeId data_type) override;
|
||||||
|
|
||||||
|
std::vector<int> shape() const override;
|
||||||
|
|
||||||
|
size_t set_shape(const std::vector<int> &shape) override;
|
||||||
|
|
||||||
|
int DimensionSize(size_t index) const override;
|
||||||
|
|
||||||
|
int ElementsNum() const override;
|
||||||
|
|
||||||
|
std::size_t hash() const override;
|
||||||
|
|
||||||
|
std::shared_ptr<tensor::Tensor> tensor() const;
|
||||||
|
|
||||||
|
size_t Size() const override;
|
||||||
|
|
||||||
|
void *MutableData() const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<tensor::Tensor> tensor_impl_;
|
||||||
|
};
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_IR_LITE_TENSOR_H_
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -505,4 +506,68 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
||||||
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
|
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
|
||||||
}));
|
}));
|
||||||
} // namespace tensor
|
} // namespace tensor
|
||||||
|
|
||||||
|
namespace inference {
|
||||||
|
MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
|
||||||
|
return new Tensor(data_type, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor() { this->tensor_impl_ = std::make_shared<tensor::Tensor>(); }
|
||||||
|
|
||||||
|
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
|
||||||
|
this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
|
||||||
|
|
||||||
|
TypeId Tensor::data_type() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->data_type();
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeId Tensor::set_data_type(TypeId data_type) {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->set_data_type(data_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> Tensor::shape() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Tensor::set_shape(const std::vector<int> &shape) {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->set_shape(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Tensor::DimensionSize(size_t index) const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->DimensionSize(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Tensor::ElementsNum() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->ElementsNum();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t Tensor::hash() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->hash();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Tensor::Size() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->data().nbytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
void *Tensor::MutableData() const {
|
||||||
|
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||||
|
return this->tensor_impl_->data_c(true);
|
||||||
|
}
|
||||||
|
} // namespace inference
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "Eigen/Core"
|
#include "Eigen/Core"
|
||||||
#include "device/device_address.h"
|
#include "device/device_address.h"
|
||||||
#include "ir/meta_tensor.h"
|
#include "ir/meta_tensor.h"
|
||||||
|
#include "include/ms_tensor.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
@ -215,6 +216,11 @@ class Tensor : public MetaTensor {
|
||||||
// return The pointer to the object
|
// return The pointer to the object
|
||||||
void *data_c(bool writable = false);
|
void *data_c(bool writable = false);
|
||||||
|
|
||||||
|
// brief Get Tensor data byte-size for c++ type
|
||||||
|
//
|
||||||
|
// return byte size of Tensor data
|
||||||
|
size_t Size() const { return this->data().nbytes(); }
|
||||||
|
|
||||||
// brief Get data type from tensor data.
|
// brief Get data type from tensor data.
|
||||||
//
|
//
|
||||||
// param buf The buffer info of the py::array data.
|
// param buf The buffer info of the py::array data.
|
||||||
|
@ -266,10 +272,45 @@ class Tensor : public MetaTensor {
|
||||||
std::string id_{""};
|
std::string id_{""};
|
||||||
DeviceAddressPtr device_address_{nullptr};
|
DeviceAddressPtr device_address_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
using TensorPtr = std::shared_ptr<Tensor>;
|
using TensorPtr = std::shared_ptr<Tensor>;
|
||||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||||
} // namespace tensor
|
} // namespace tensor
|
||||||
|
|
||||||
|
namespace inference {
|
||||||
|
class Tensor : public MSTensor {
|
||||||
|
public:
|
||||||
|
Tensor();
|
||||||
|
|
||||||
|
Tensor(TypeId data_type, const std::vector<int> &shape);
|
||||||
|
|
||||||
|
explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
|
||||||
|
|
||||||
|
~Tensor() = default;
|
||||||
|
|
||||||
|
TypeId data_type() const override;
|
||||||
|
|
||||||
|
TypeId set_data_type(const TypeId data_type) override;
|
||||||
|
|
||||||
|
std::vector<int> shape() const override;
|
||||||
|
|
||||||
|
size_t set_shape(const std::vector<int> &shape) override;
|
||||||
|
|
||||||
|
int DimensionSize(size_t index) const override;
|
||||||
|
|
||||||
|
int ElementsNum() const override;
|
||||||
|
|
||||||
|
std::size_t hash() const override;
|
||||||
|
|
||||||
|
std::shared_ptr<tensor::Tensor> tensor() const;
|
||||||
|
|
||||||
|
size_t Size() const override;
|
||||||
|
|
||||||
|
void *MutableData() const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<tensor::Tensor> tensor_impl_;
|
||||||
|
};
|
||||||
|
} // namespace inference
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_IR_TENSOR_H_
|
#endif // MINDSPORE_CCSRC_IR_TENSOR_H_
|
||||||
|
|
|
@ -90,6 +90,30 @@ class OpIOInfo {
|
||||||
class OpInfo {
|
class OpInfo {
|
||||||
public:
|
public:
|
||||||
OpInfo() = default;
|
OpInfo() = default;
|
||||||
|
OpInfo(const OpInfo &opinfo) {
|
||||||
|
op_name_ = opinfo.op_name();
|
||||||
|
imply_type_ = opinfo.imply_type();
|
||||||
|
|
||||||
|
impl_path_ = opinfo.impl_path();
|
||||||
|
fusion_type_ = opinfo.fusion_type();
|
||||||
|
async_flag_ = opinfo.async_flag_;
|
||||||
|
binfile_name_ = opinfo.binfile_name_;
|
||||||
|
compute_cost_ = opinfo.compute_cost_;
|
||||||
|
kernel_name_ = opinfo.kernel_name();
|
||||||
|
partial_flag_ = opinfo.partial_flag_;
|
||||||
|
dynamic_format_ = opinfo.dynamic_format_;
|
||||||
|
op_pattern_ = opinfo.op_pattern();
|
||||||
|
for (auto attr : opinfo.attrs_ptr()) {
|
||||||
|
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
|
||||||
|
}
|
||||||
|
for (auto input : opinfo.inputs_ptr()) {
|
||||||
|
inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input));
|
||||||
|
}
|
||||||
|
for (auto output : opinfo.outputs_ptr()) {
|
||||||
|
outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output));
|
||||||
|
}
|
||||||
|
ref_infos_ = opinfo.ref_infos();
|
||||||
|
}
|
||||||
~OpInfo() = default;
|
~OpInfo() = default;
|
||||||
std::string op_name() const { return op_name_; }
|
std::string op_name() const { return op_name_; }
|
||||||
OpImplyType imply_type() const { return imply_type_; }
|
OpImplyType imply_type() const { return imply_type_; }
|
||||||
|
|
|
@ -29,7 +29,12 @@ class OpLib {
|
||||||
OpLib() = default;
|
OpLib() = default;
|
||||||
virtual ~OpLib() = default;
|
virtual ~OpLib() = default;
|
||||||
bool RegOp(const std::string &json_string, const std::string &impl_path);
|
bool RegOp(const std::string &json_string, const std::string &impl_path);
|
||||||
|
static void RegOpInfo(std::shared_ptr<OpInfo> opinfo) {
|
||||||
|
op_info_.emplace_back(opinfo);
|
||||||
|
return;
|
||||||
|
}
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
||||||
|
static const std::vector<std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_OPLOADER_H
|
||||||
|
#define MINDSPORE_OPLOADER_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "kernel/oplib/oplib.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class OpInfoLoaderPy {
|
||||||
|
public:
|
||||||
|
OpInfoLoaderPy() = default;
|
||||||
|
|
||||||
|
~OpInfoLoaderPy() = default;
|
||||||
|
|
||||||
|
size_t GetAllOpsInfo() {
|
||||||
|
auto ops = OpLib::GetAllOpsInfo();
|
||||||
|
auto op_infos = new std::vector<OpInfo *>();
|
||||||
|
for (auto op_info : ops) {
|
||||||
|
auto new_op_info = new OpInfo(*op_info);
|
||||||
|
op_infos->emplace_back(new_op_info);
|
||||||
|
}
|
||||||
|
return (size_t)op_infos;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_OPLOADER_H
|
|
@ -1,76 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_MINNIE_TENSOR_MINNIE_H_
|
|
||||||
#define MINDSPORE_CCSRC_MINNIE_TENSOR_MINNIE_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ir/meta_tensor.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace tensor {
|
|
||||||
// definition of Tensor Minnie
|
|
||||||
class TensorMinnie : public MetaTensor {
|
|
||||||
public:
|
|
||||||
TensorMinnie() : MetaTensor() {}
|
|
||||||
~TensorMinnie() override = default;
|
|
||||||
MS_DECLARE_PARENT(TensorMinnie, MetaTensor)
|
|
||||||
|
|
||||||
// brief Overloads operator = for TensorMinnie.
|
|
||||||
//
|
|
||||||
// The constructed TensorMinnie object has the same type and shape with tensor_base.
|
|
||||||
//
|
|
||||||
// param meta_tensor An existing TensorMinnie object.
|
|
||||||
virtual TensorMinnie &operator=(const TensorMinnie &tensor);
|
|
||||||
|
|
||||||
// brief Compares two TensorMinnie objects.
|
|
||||||
//
|
|
||||||
// The constructed TensorMinnie object has the same type and shape with tensor_base.
|
|
||||||
//
|
|
||||||
// param meta_tensor The TensorMinnie object to be compared.
|
|
||||||
// return true: If having same type and shape, return true, or return false.
|
|
||||||
virtual bool operator==(const TensorMinnie &tensor);
|
|
||||||
|
|
||||||
// brief Get the tensor's size for C++
|
|
||||||
//
|
|
||||||
// return size_t
|
|
||||||
size_t tensor_size() const { return tensor_size_; }
|
|
||||||
|
|
||||||
// brief Set Tensor data size for c++ type
|
|
||||||
void set_tensor_size(size_t size) { tensor_size_ = size; }
|
|
||||||
|
|
||||||
// brief Get Tensor data pointer for c++ type
|
|
||||||
//
|
|
||||||
// return The pointer to the object
|
|
||||||
void *tensor_addr() const { return tensor_addr_; }
|
|
||||||
|
|
||||||
// brief Set Tensor data pointer for c++ type
|
|
||||||
void set_tensor_addr(void *addr) { tensor_addr_ = addr; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
// brief Data addr of the tensor.
|
|
||||||
void *tensor_addr_;
|
|
||||||
|
|
||||||
// brief Data size of the tensor.
|
|
||||||
size_t tensor_size_;
|
|
||||||
};
|
|
||||||
|
|
||||||
using TensorMinniePtr = std::shared_ptr<TensorMinnie>;
|
|
||||||
} // namespace tensor
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_MINNIE_TENSOR_MINNIE_H_
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include <pybind11/operators.h>
|
#include <pybind11/operators.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
#include "kernel/oplib/oplib.h"
|
#include "kernel/oplib/oplib.h"
|
||||||
|
#include "kernel/oplib/oploader.h"
|
||||||
#include "pipeline/pipeline.h"
|
#include "pipeline/pipeline.h"
|
||||||
#include "operator/composite/composite.h"
|
#include "operator/composite/composite.h"
|
||||||
#include "ir/signature.h"
|
#include "ir/signature.h"
|
||||||
|
@ -45,6 +46,7 @@ using PrimitivePy = mindspore::PrimitivePy;
|
||||||
using MetaFuncGraph = mindspore::MetaFuncGraph;
|
using MetaFuncGraph = mindspore::MetaFuncGraph;
|
||||||
using EventWriter = mindspore::summary::EventWriter;
|
using EventWriter = mindspore::summary::EventWriter;
|
||||||
using OpLib = mindspore::kernel::OpLib;
|
using OpLib = mindspore::kernel::OpLib;
|
||||||
|
using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
|
||||||
using ParallelContext = mindspore::parallel::ParallelContext;
|
using ParallelContext = mindspore::parallel::ParallelContext;
|
||||||
using CostModelContext = mindspore::parallel::CostModelContext;
|
using CostModelContext = mindspore::parallel::CostModelContext;
|
||||||
|
|
||||||
|
@ -325,4 +327,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
"Finalize gpu collective communication mode.");
|
"Finalize gpu collective communication mode.");
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||||
|
.def(py::init())
|
||||||
|
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "include/inference.h"
|
||||||
|
#include "session/session.h"
|
||||||
|
#include "utils/load_onnx/anf_converter.h"
|
||||||
|
#include "session/session_basic.h"
|
||||||
|
#include "session/session_factory.h"
|
||||||
|
#include "utils/base_ref_utils.h"
|
||||||
|
#include "kernel/oplib/oplib.h"
|
||||||
|
#ifdef ENABLE_D
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "session/ascend_session.h"
|
||||||
|
#else
|
||||||
|
#include "session/cpu_session.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
namespace mindspore::inference {
|
||||||
|
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||||
|
inference::Session::RegAllOp();
|
||||||
|
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||||
|
return anf_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||||
|
auto session = std::make_shared<inference::Session>();
|
||||||
|
auto ret = session->Init(device, device_id);
|
||||||
|
if (ret != 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Session::RegAllOp() {
|
||||||
|
static std::mutex init_mutex;
|
||||||
|
static bool Initialized = false;
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(init_mutex);
|
||||||
|
if (Initialized) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Initialized = true;
|
||||||
|
MsContext::GetInstance()->set_execution_mode(kGraphMode);
|
||||||
|
Py_Initialize();
|
||||||
|
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||||
|
if (c_expression == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||||
|
|
||||||
|
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||||
|
if (op_info_loader_class == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||||
|
if (op_info_loader == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||||
|
if (op_info_loader_ins == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||||
|
if (all_ops_info_vector_addr_ul == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||||
|
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||||
|
for (auto op_info : *all_ops_info) {
|
||||||
|
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||||
|
}
|
||||||
|
all_ops_info->clear();
|
||||||
|
delete all_ops_info;
|
||||||
|
Py_DECREF(op_info_loader);
|
||||||
|
Py_DECREF(op_info_loader_class);
|
||||||
|
Py_DECREF(c_expression_dict);
|
||||||
|
Py_DECREF(c_expression);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||||
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
|
return session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||||
|
std::vector<tensor::TensorPtr> inTensors;
|
||||||
|
bool has_error = false;
|
||||||
|
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
|
||||||
|
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
|
||||||
|
if (tensor_ptr == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
|
||||||
|
has_error = true;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
|
||||||
|
has_error = true;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensor->tensor();
|
||||||
|
});
|
||||||
|
if (has_error) {
|
||||||
|
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
|
||||||
|
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor;
|
||||||
|
return multiTensor;
|
||||||
|
}
|
||||||
|
VectorRef outputs;
|
||||||
|
session_impl_->RunGraph(graph_id, inTensors, &outputs);
|
||||||
|
|
||||||
|
return TransformVectorRefToMultiTensor(outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Session::Init(const std::string &device, uint32_t device_id) {
|
||||||
|
RegAllOp();
|
||||||
|
session_impl_ = session::SessionFactory::Get().Create(device);
|
||||||
|
if (session_impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
session_impl_->Init(device_id);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Session::Session() = default;
|
||||||
|
} // namespace mindspore::inference
|
|
@ -0,0 +1,50 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||||
|
#define MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "session/session_basic.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "include/inference.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace inference {
|
||||||
|
class Session : public MSSession {
|
||||||
|
public:
|
||||||
|
Session();
|
||||||
|
|
||||||
|
uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) override;
|
||||||
|
|
||||||
|
MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
|
||||||
|
|
||||||
|
int Init(const std::string &device, uint32_t device_id);
|
||||||
|
|
||||||
|
static void RegAllOp();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||||
|
std::vector<uint32_t> graph_id_;
|
||||||
|
};
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
|
@ -5,5 +5,11 @@ if (NOT ENABLE_GE)
|
||||||
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES})
|
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES})
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE _UTILS_LITE_SRC_FILES
|
||||||
|
./load_onnx/anf_converter.cc
|
||||||
|
./load_onnx/anf_model_parser.cc
|
||||||
|
)
|
||||||
|
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_LITE_SRC_FILES})
|
||||||
|
|
||||||
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
|
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
|
||||||
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})
|
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "utils/base_ref_utils.h"
|
||||||
|
#include "include/ms_tensor.h"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref) {
|
||||||
|
std::vector<std::shared_ptr<inference::MSTensor>> msTensors;
|
||||||
|
if (utils::isa<VectorRef>(base_ref)) {
|
||||||
|
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||||
|
for (size_t i = 0; i < ref_list.size(); ++i) {
|
||||||
|
if (utils::isa<tensor::Tensor>(ref_list[i])) {
|
||||||
|
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||||
|
auto tensor = new inference::Tensor(tensor_ptr);
|
||||||
|
msTensors.emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "The output is not a tensor!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (utils::isa<tensor::Tensor>(base_ref)) {
|
||||||
|
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||||
|
auto tensor = new inference::Tensor(tensor_ptr);
|
||||||
|
msTensors.emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
|
||||||
|
}
|
||||||
|
return msTensors;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor(
|
||||||
|
const VectorRef &vector_ref) {
|
||||||
|
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor;
|
||||||
|
for (size_t i = 0; i < vector_ref.size(); ++i) {
|
||||||
|
auto tensors = TransformBaseRefToMSTensor(vector_ref[i]);
|
||||||
|
multiTensor.emplace_back(tensors);
|
||||||
|
}
|
||||||
|
return multiTensor;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -14,21 +14,17 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "minnie/tensor_minnie.h"
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "utils/base_ref.h"
|
||||||
|
#include "include/ms_tensor.h"
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||||
|
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace tensor {
|
std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);
|
||||||
TensorMinnie &TensorMinnie::operator=(const TensorMinnie &tensor) {
|
|
||||||
if (&tensor == this) {
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
this->tensor_addr_ = tensor.tensor_addr();
|
|
||||||
this->tensor_size_ = tensor.tensor_size();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool TensorMinnie::operator==(const TensorMinnie &tensor) {
|
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor(
|
||||||
return tensor_addr_ == tensor.tensor_addr() && tensor_size_ == tensor.tensor_size();
|
const VectorRef &vector_ref);
|
||||||
}
|
|
||||||
} // namespace tensor
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
|
@ -0,0 +1,143 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
|
#include "utils/load_onnx/anf_model_parser.h"
|
||||||
|
#include "utils/load_onnx/anf_converter.h"
|
||||||
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||||
|
#include "proto/onnx.pb.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
|
||||||
|
const char WHITESPACE[] = "\t\n\v\f\r ";
|
||||||
|
const int FLAG_PREFIX_LEN = 2;
|
||||||
|
|
||||||
|
void AnfConverter::Trim(std::string *input) {
|
||||||
|
if (input == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (input->empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
input->erase(0, input->find_first_not_of(WHITESPACE));
|
||||||
|
input->erase(input->find_last_not_of(WHITESPACE) + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string fileType) {
|
||||||
|
if (modelFile.size() > fileType.size()) {
|
||||||
|
if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
|
||||||
|
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
||||||
|
int fd = open(onnx_file.get(), O_RDONLY);
|
||||||
|
google::protobuf::io::FileInputStream input(fd);
|
||||||
|
google::protobuf::io::CodedInputStream code_input(&input);
|
||||||
|
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
|
||||||
|
bool ret = onnx_model->ParseFromCodedStream(&code_input);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(ERROR) << "load onnx file failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(void)close(fd);
|
||||||
|
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file_path) {
|
||||||
|
std::string modelFile;
|
||||||
|
|
||||||
|
std::string tmp = file_path;
|
||||||
|
Trim(&tmp);
|
||||||
|
const std::string flagItem(tmp);
|
||||||
|
|
||||||
|
size_t pos = flagItem.find_first_of("=");
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
MS_LOG(ERROR) << "Trans data not support input format!";
|
||||||
|
} else {
|
||||||
|
modelFile = flagItem.substr(pos + 1);
|
||||||
|
std::cout << "input protobuf file path is: " << flagItem.substr(pos + 1) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ValidateFileStr(modelFile, ".pb") != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "INPUT ILLEGAL: modelFile must be *.pb";
|
||||||
|
}
|
||||||
|
|
||||||
|
onnx::ModelProto model_;
|
||||||
|
ReadOnnxFromBinary(modelFile, &model_);
|
||||||
|
MSANFModelParser model_parser;
|
||||||
|
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||||
|
MS_EXCEPTION_IF_NULL(dstgraph_ptr);
|
||||||
|
TestFuncGraphBuild(dstgraph_ptr);
|
||||||
|
return dstgraph_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const size_t buf_size) {
|
||||||
|
Py_Initialize();
|
||||||
|
MS_EXCEPTION_IF_NULL(buf);
|
||||||
|
std::string str((const char *)buf, buf_size);
|
||||||
|
onnx::ModelProto model_;
|
||||||
|
if (!model_.ParseFromString(str)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parse model from buffer fail!";
|
||||||
|
}
|
||||||
|
MSANFModelParser model_parser;
|
||||||
|
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||||
|
MS_EXCEPTION_IF_NULL(dstgraph_ptr);
|
||||||
|
TestFuncGraphBuild(dstgraph_ptr);
|
||||||
|
return dstgraph_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AnfConverter::TestFuncGraphBuild(const FuncGraphPtr &graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto node_return = graph->get_return();
|
||||||
|
std::vector<AnfNodePtr> node_list = TopoSort(node_return);
|
||||||
|
MS_LOG(INFO) << "node_list size is : " << node_list.size();
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
auto node_CN = node->cast<CNodePtr>();
|
||||||
|
MS_LOG(INFO) << "CN node: " << node_CN->input(0)->ToString() << ", input size :" << node_CN->size();
|
||||||
|
} else if (node->isa<Parameter>()) {
|
||||||
|
auto node_Para = node->cast<ParameterPtr>();
|
||||||
|
if (node_Para->has_default()) {
|
||||||
|
MS_LOG(INFO) << "Parameter node: " << node_Para->name() << "has default value!";
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "Parameter node: " << node_Para->name();
|
||||||
|
}
|
||||||
|
} else if (node->isa<ValueNode>()) {
|
||||||
|
auto node_Value = node->cast<ValueNodePtr>();
|
||||||
|
MS_LOG(INFO) << "Value node: " << node_Value->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H
|
||||||
|
#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||||
|
#include "proto/onnx.pb.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class AnfConverter {
|
||||||
|
public:
|
||||||
|
static int TestFuncGraphBuild(const FuncGraphPtr &graph);
|
||||||
|
static std::shared_ptr<FuncGraph> RunAnfConverter(const std::string &file_path);
|
||||||
|
static std::shared_ptr<FuncGraph> RunAnfConverter(const char *buf, const size_t buf_size);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void Trim(std::string *input);
|
||||||
|
static int ValidateFileStr(const std::string &modelFile, std::string fileType);
|
||||||
|
static bool ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model);
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif
|
|
@ -0,0 +1,515 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "utils/load_onnx/anf_model_parser.h"
|
||||||
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
#include "ir/param_value_py.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "proto/onnx.pb.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
static constexpr char kConstantValueNode[] = "Constant";
|
||||||
|
static constexpr char kCNodeShapeAttr[] = "shape";
|
||||||
|
enum ParseForm : int {
|
||||||
|
FORM_PARSE_TYPE = 0,
|
||||||
|
FORM_PARSE_SCALAR = 1,
|
||||||
|
FORM_PARSE_TENSOR = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
|
||||||
|
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
|
||||||
|
|
||||||
|
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||||
|
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||||
|
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||||
|
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||||
|
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||||
|
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||||
|
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||||
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||||
|
};
|
||||||
|
|
||||||
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||||
|
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
||||||
|
const onnx::TensorProto &attr_tensor) { \
|
||||||
|
MS_EXCEPTION_IF_NULL(prim); \
|
||||||
|
std::vector<valuetype> attr_value_vec; \
|
||||||
|
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||||
|
attr_value_vec.push_back(static_cast<valuetype>(attr_tensor.type##_data(i))); \
|
||||||
|
} \
|
||||||
|
if (attr_value_vec.size() == 1) { \
|
||||||
|
prim->AddAttr(attr_name, MakeValue<valuetype>(attr_value_vec[0])); \
|
||||||
|
} else { \
|
||||||
|
prim->AddAttr(attr_name, MakeValue<std::vector<valuetype>>(attr_value_vec)); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||||
|
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||||
|
|
||||||
|
bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!value_proto.has_type() || !value_proto.has_name()) {
|
||||||
|
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
node->set_name(value_proto.name());
|
||||||
|
const auto &type_proto = value_proto.type();
|
||||||
|
if (!type_proto.has_tensor_type()) {
|
||||||
|
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
|
||||||
|
if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
|
||||||
|
MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
|
||||||
|
std::vector<int> shape;
|
||||||
|
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
|
||||||
|
shape.push_back(tensor_shape.dim(i).dim_value());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) {
|
||||||
|
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor::TensorPtr tensor_info =
|
||||||
|
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||||
|
auto tensor_abstract = tensor_info->ToAbstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||||
|
node->set_abstract(tensor_abstract);
|
||||||
|
|
||||||
|
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
|
||||||
|
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
|
||||||
|
std::string initial_data = initialize_proto.raw_data();
|
||||||
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true));
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||||
|
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
|
||||||
|
|
||||||
|
py::array array_data = tensor_info->data();
|
||||||
|
ParamValuePyPtr para_value_ptr = std::make_shared<ParamValuePy>();
|
||||||
|
MS_EXCEPTION_IF_NULL(para_value_ptr);
|
||||||
|
para_value_ptr->set_value(array_data);
|
||||||
|
node->set_default_param(para_value_ptr);
|
||||||
|
}
|
||||||
|
anfnode_build_map_[value_proto.name()] = node;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
|
const onnx::GraphProto &importProto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||||
|
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
||||||
|
|
||||||
|
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
||||||
|
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
|
||||||
|
if (!initializer_proto.has_name()) {
|
||||||
|
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
default_para_map_[initializer_proto.name()] = initializer_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
|
||||||
|
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||||
|
const onnx::ValueInfoProto &input_proto = importProto.input(i);
|
||||||
|
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
|
||||||
|
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||||
|
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
switch (attr_tensor_type) {
|
||||||
|
case onnx::TensorProto_DataType_STRING: {
|
||||||
|
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_INT32: {
|
||||||
|
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_INT64: {
|
||||||
|
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_UINT64: {
|
||||||
|
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_FLOAT: {
|
||||||
|
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_DOUBLE: {
|
||||||
|
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_BOOL: {
|
||||||
|
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
|
||||||
|
auto value = prim->GetAttr(attr_name);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
const std::string &attr_name = attr_proto.name();
|
||||||
|
if (!attr_proto.has_ref_attr_name()) {
|
||||||
|
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||||
|
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||||
|
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||||
|
case FORM_PARSE_TYPE: {
|
||||||
|
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
||||||
|
}
|
||||||
|
case FORM_PARSE_SCALAR: {
|
||||||
|
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
|
||||||
|
}
|
||||||
|
case FORM_PARSE_TENSOR: {
|
||||||
|
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
std::vector<int> shape;
|
||||||
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||||
|
shape.push_back(attr_tensor.dims(i));
|
||||||
|
}
|
||||||
|
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||||
|
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||||
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true));
|
||||||
|
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||||
|
if (attr_tensor_type == onnx::TensorProto_DataType_FLOAT) {
|
||||||
|
auto *data_valuennode = reinterpret_cast<float *>(tensor_info->data_c());
|
||||||
|
MS_EXCEPTION_IF_NULL(data_valuennode);
|
||||||
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuennode));
|
||||||
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
|
} else {
|
||||||
|
auto *data_valuenode = reinterpret_cast<int32 *>(tensor_info->data_c());
|
||||||
|
MS_EXCEPTION_IF_NULL(data_valuenode);
|
||||||
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuenode));
|
||||||
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ObtainValueNodeInScalarForm(const std::string &value_node_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
ValuePtr value_ptr = nullptr;
|
||||||
|
switch (attr_tensor_type) {
|
||||||
|
case onnx::TensorProto_DataType_INT32: {
|
||||||
|
std::vector<int32> add_data;
|
||||||
|
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
|
||||||
|
add_data.push_back(attr_tensor.int32_data(i));
|
||||||
|
}
|
||||||
|
if (add_data.size() == 1) {
|
||||||
|
value_ptr = MakeValue(add_data[0]);
|
||||||
|
} else if (!add_data.empty()) {
|
||||||
|
value_ptr = MakeValue<std::vector<int32>>(add_data);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_FLOAT: {
|
||||||
|
std::vector<float> add_data;
|
||||||
|
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
|
||||||
|
add_data.push_back(attr_tensor.float_data(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (add_data.size() == 1) {
|
||||||
|
value_ptr = MakeValue(add_data[0]);
|
||||||
|
} else if (!add_data.empty()) {
|
||||||
|
value_ptr = MakeValue<std::vector<float>>(add_data);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case onnx::TensorProto_DataType_UNDEFINED: {
|
||||||
|
std::vector<ValuePtr> elems;
|
||||||
|
value_ptr = std::make_shared<ValueTuple>(elems);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto new_value_node = NewValueNode(value_ptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||||
|
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||||
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||||
|
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto new_value_node = std::make_shared<ValueNode>(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||||
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &ref_attr_name, const std::string &value_node_name,
|
||||||
|
const onnx::TensorProto &attr_tensor) {
|
||||||
|
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||||
|
case FORM_PARSE_SCALAR: {
|
||||||
|
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
|
||||||
|
}
|
||||||
|
case FORM_PARSE_TENSOR: {
|
||||||
|
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||||
|
}
|
||||||
|
case FORM_PARSE_TYPE: {
|
||||||
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||||
|
const std::string &value_node_name = node_proto.output(0);
|
||||||
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||||
|
if (!attr_proto.has_ref_attr_name()) {
|
||||||
|
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||||
|
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||||
|
|
||||||
|
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
||||||
|
std::vector<int> shape_vec;
|
||||||
|
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||||
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||||
|
shape_vec.push_back(attr_tensor.dims(i));
|
||||||
|
}
|
||||||
|
tensor::TensorPtr tensor_info =
|
||||||
|
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||||
|
return tensor_info->ToAbstract();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||||
|
const onnx::GraphProto &importProto, const bool &ret_flag) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||||
|
if (!node_proto.has_op_type()) {
|
||||||
|
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const std::string &node_name = node_proto.output(0);
|
||||||
|
const std::string &node_type = node_proto.op_type();
|
||||||
|
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
|
||||||
|
AbstractBasePtr abstract;
|
||||||
|
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||||
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||||
|
if (attr_proto.name() == kCNodeShapeAttr) {
|
||||||
|
abstract = GetAbstractForCNode(attr_proto);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||||
|
MS_LOG(ERROR) << "Get CNode attr failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> inputs;
|
||||||
|
inputs.clear();
|
||||||
|
inputs.push_back(NewValueNode(prim));
|
||||||
|
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||||
|
const std::string &input_name = node_proto.input(i);
|
||||||
|
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||||
|
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
inputs.push_back(anfnode_build_map_[input_name]);
|
||||||
|
}
|
||||||
|
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||||
|
cnode_ptr->set_abstract(abstract);
|
||||||
|
if (ret_flag) {
|
||||||
|
const onnx::ValueInfoProto &output_node = importProto.output(0);
|
||||||
|
const ::onnx::TypeProto &output_typeproto = output_node.type();
|
||||||
|
int output_type = output_typeproto.tensor_type().elem_type();
|
||||||
|
std::vector<int> output_shape;
|
||||||
|
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
|
||||||
|
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
|
||||||
|
}
|
||||||
|
tensor::TensorPtr tensor_return =
|
||||||
|
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[output_type], output_shape);
|
||||||
|
inputs.clear();
|
||||||
|
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||||
|
inputs.push_back(cnode_ptr);
|
||||||
|
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||||
|
return_node->set_abstract(tensor_return->ToAbstract());
|
||||||
|
outputFuncGraph->set_return(return_node);
|
||||||
|
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||||
|
}
|
||||||
|
anfnode_build_map_[node_name] = cnode_ptr;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||||
|
bool return_flag = false;
|
||||||
|
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||||
|
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||||
|
return_flag = (i == importProto.node_size() - 1) ? true : return_flag;
|
||||||
|
const onnx::NodeProto &node_proto = importProto.node(i);
|
||||||
|
const std::string &node_type = node_proto.op_type();
|
||||||
|
if (node_type == kConstantValueNode) {
|
||||||
|
if (!BuildValueNodeForFuncGraph(node_proto)) {
|
||||||
|
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!BuildCNodeForFuncGraph(outputFuncGraph, node_proto, importProto, return_flag)) {
|
||||||
|
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||||
|
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
||||||
|
MS_EXCEPTION_IF_NULL(debug_info_ptr);
|
||||||
|
if (importProto.has_name()) {
|
||||||
|
debug_info_ptr->set_name(importProto.name());
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return ImportNodesForGraph(outputFuncGraph, importProto);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||||
|
if (!model_proto.has_producer_name()) {
|
||||||
|
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
producer_name_ = model_proto.producer_name();
|
||||||
|
MS_LOG(INFO) << "producer_name :" << producer_name_;
|
||||||
|
|
||||||
|
if (!model_proto.has_producer_version()) {
|
||||||
|
MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
producer_version_ = model_proto.producer_version();
|
||||||
|
MS_LOG(INFO) << "producer_version : " << producer_version_;
|
||||||
|
|
||||||
|
if (!model_proto.has_ir_version()) {
|
||||||
|
MS_LOG(ERROR) << "Parse model version from pb file failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ir_version_ = model_proto.ir_version();
|
||||||
|
MS_LOG(INFO) << "ir_version :" << ir_version_;
|
||||||
|
|
||||||
|
const onnx::OperatorSetIdProto &opset_proto = model_proto.opset_import(0);
|
||||||
|
if (!opset_proto.has_version()) {
|
||||||
|
MS_LOG(ERROR) << "Parse opset version from pb file failed!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
opset_version_ = opset_proto.version();
|
||||||
|
MS_LOG(INFO) << "opset_version : " << opset_version_;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) {
|
||||||
|
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
|
||||||
|
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||||
|
if (!MSANFParseModelConfigureInfo(model_proto)) {
|
||||||
|
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const onnx::GraphProto &graphBuild = model_proto.graph();
|
||||||
|
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
||||||
|
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
|
||||||
|
return dstGraph;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
||||||
|
#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "proto/onnx.pb.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
using int32 = int32_t;
|
||||||
|
using int64 = int64_t;
|
||||||
|
using uint64 = uint64_t;
|
||||||
|
class MSANFModelParser {
|
||||||
|
public:
|
||||||
|
MSANFModelParser() = default;
|
||||||
|
~MSANFModelParser() = default;
|
||||||
|
|
||||||
|
FuncGraphPtr Parse(const onnx::ModelProto &model_proto);
|
||||||
|
bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||||
|
|
||||||
|
std::string GetProducerName() { return producer_name_; }
|
||||||
|
std::string GetProducerVersion() { return producer_version_; }
|
||||||
|
int GetIrVersion() { return ir_version_; }
|
||||||
|
int GetOpsetVersion() { return opset_version_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||||
|
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||||
|
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||||
|
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||||
|
bool BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||||
|
const onnx::GraphProto &importProto, const bool &ret_flag);
|
||||||
|
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||||
|
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor);
|
||||||
|
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor);
|
||||||
|
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
const onnx::TensorProto &attr_tensor);
|
||||||
|
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||||
|
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||||
|
|
||||||
|
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||||
|
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
||||||
|
const onnx::TensorProto &attr_tensor);
|
||||||
|
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||||
|
|
||||||
|
std::string producer_name_;
|
||||||
|
std::string producer_version_;
|
||||||
|
int ir_version_{};
|
||||||
|
int opset_version_{};
|
||||||
|
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||||
|
std::map<std::string, onnx::TensorProto> default_para_map_;
|
||||||
|
|
||||||
|
AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
|
@ -106,9 +106,12 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
)
|
)
|
||||||
|
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/dump_proto.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/dump_proto.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ir/lite/tensor.cc")
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/anf_ir.pb.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/anf_ir.pb.cc")
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/node_strategy.pb.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/node_strategy.pb.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_converter.cc")
|
||||||
|
|
||||||
file(GLOB_RECURSE UT_SUTB_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
file(GLOB_RECURSE UT_SUTB_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"stub/aicpu/*.cc"
|
"stub/aicpu/*.cc"
|
||||||
|
|
Loading…
Reference in New Issue