forked from mindspore-Ecosystem/mindspore
!26926 Provide new mindspore core API classes
Merge pull request !26926 from hewei/core_api
This commit is contained in:
commit
95af195d0c
|
@ -221,6 +221,8 @@ if(PLATFORM_ARM64)
|
|||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
|
@ -298,6 +300,8 @@ elseif(PLATFORM_ARM32)
|
|||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
|
@ -365,6 +369,8 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
|
@ -409,6 +415,8 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
|
@ -430,6 +438,10 @@ else()
|
|||
PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE)
|
||||
install(FILES ${API_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${MINDAPI_BASE_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/base
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${MINDAPI_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/ir
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${ABSTRACT_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${API_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api/ir
|
||||
|
|
|
@ -23,6 +23,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"ir/*.cc"
|
||||
"utils/*.cc"
|
||||
"load_mindir/*.cc"
|
||||
"mindapi/src/*.cc"
|
||||
)
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
|
|
|
@ -24,8 +24,7 @@
|
|||
#include "utils/visible.h"
|
||||
#include "api/ir/func_graph_manager.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
|
||||
namespace mindspore::deprecated::api {
|
||||
/// \brief FuncGraph defines interface for a function graph.
|
||||
class MS_CORE_API FuncGraph {
|
||||
public:
|
||||
|
@ -147,5 +146,12 @@ class MS_CORE_API FuncGraph {
|
|||
/// \return The function graph if the input is value node that holds the graph, nullptr otherwise.
|
||||
static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#ifndef USE_DEPRECATED_API
|
||||
#define USE_DEPRECATED_API
|
||||
namespace mindspore {
|
||||
namespace api = deprecated::api;
|
||||
}
|
||||
#endif
|
||||
} // namespace mindspore::deprecated::api
|
||||
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_
|
||||
|
|
|
@ -26,8 +26,7 @@
|
|||
#include "utils/hashing.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
|
||||
namespace mindspore::deprecated::api {
|
||||
class FuncGraph;
|
||||
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
|
||||
|
||||
|
@ -80,7 +79,13 @@ class MS_CORE_API FuncGraphManager {
|
|||
/// \return The manager that manages the given function graph.
|
||||
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
|
||||
};
|
||||
} // namespace mindspore::deprecated::api
|
||||
|
||||
} // namespace mindspore::api
|
||||
#ifndef USE_DEPRECATED_API
|
||||
#define USE_DEPRECATED_API
|
||||
namespace mindspore {
|
||||
namespace api = deprecated::api;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_MANAGER_H_
|
||||
|
|
|
@ -52,6 +52,7 @@ static const std::vector<std::string> sub_module_names = {
|
|||
"HCCL_ADPT", // SM_HCCL_ADPT
|
||||
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK
|
||||
"GE", // SM_GE
|
||||
"API", // SM_API
|
||||
};
|
||||
|
||||
const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_names[(module_id % NUM_SUBMODUES)]; }
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,87 +19,6 @@
|
|||
#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
#define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
|
||||
namespace mindspore {
|
||||
//
|
||||
// Supported meta type
|
||||
//
|
||||
enum TypeId : int {
|
||||
kTypeUnknown = 0,
|
||||
kMetaTypeBegin = kTypeUnknown,
|
||||
kMetaTypeType, // Type
|
||||
kMetaTypeAnything,
|
||||
kMetaTypeObject,
|
||||
kMetaTypeTypeType, // TypeType
|
||||
kMetaTypeProblem,
|
||||
kMetaTypeExternal,
|
||||
kMetaTypeNone,
|
||||
kMetaTypeNull,
|
||||
kMetaTypeEllipsis,
|
||||
kMetaTypeEnd,
|
||||
//
|
||||
// Object types
|
||||
//
|
||||
kObjectTypeBegin = kMetaTypeEnd,
|
||||
kObjectTypeNumber,
|
||||
kObjectTypeString,
|
||||
kObjectTypeList,
|
||||
kObjectTypeTuple,
|
||||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeRowTensorType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
kObjectTypeFunction,
|
||||
kObjectTypeJTagged,
|
||||
kObjectTypeSymbolicKeyType,
|
||||
kObjectTypeEnvType,
|
||||
kObjectTypeRefKey,
|
||||
kObjectTypeRef,
|
||||
kObjectTypeEnd,
|
||||
//
|
||||
// Number Types
|
||||
//
|
||||
kNumberTypeBegin = kObjectTypeEnd,
|
||||
kNumberTypeBool,
|
||||
kNumberTypeInt,
|
||||
kNumberTypeInt8,
|
||||
kNumberTypeInt16,
|
||||
kNumberTypeInt32,
|
||||
kNumberTypeInt64,
|
||||
kNumberTypeUInt,
|
||||
kNumberTypeUInt8,
|
||||
kNumberTypeUInt16,
|
||||
kNumberTypeUInt32,
|
||||
kNumberTypeUInt64,
|
||||
kNumberTypeFloat,
|
||||
kNumberTypeFloat16,
|
||||
kNumberTypeFloat32,
|
||||
kNumberTypeFloat64,
|
||||
kNumberTypeComplex,
|
||||
kNumberTypeComplex64,
|
||||
kNumberTypeComplex128,
|
||||
kNumberTypeInt4,
|
||||
kNumberTypeGLUInt,
|
||||
kNumberTypeEnd,
|
||||
//
|
||||
// Monad Types
|
||||
//
|
||||
kMonadTypeBegin = kNumberTypeEnd,
|
||||
kObjectTypeMonad,
|
||||
kObjectTypeUMonad,
|
||||
kObjectTypeIOMonad,
|
||||
kMonadTypeEnd,
|
||||
//
|
||||
// Sparse Types
|
||||
//
|
||||
// Sparse types is placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
kSparseTypeBegin = kMonadTypeEnd,
|
||||
kObjectTypeCSRTensorType,
|
||||
kSparseTypeEnd
|
||||
};
|
||||
} // namespace mindspore
|
||||
#include "mindapi/base/type_id.h"
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
|
|
|
@ -153,7 +153,7 @@ class FuncGraphBase : public Value {
|
|||
MS_DECLARE_PARENT(FuncGraphBase, Value);
|
||||
};
|
||||
|
||||
class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
|
||||
class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
|
||||
public:
|
||||
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
|
||||
|
||||
|
@ -265,7 +265,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
|
|||
FuncGraphManagerPtr manager() const { return manager_.lock(); }
|
||||
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
|
||||
|
||||
api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
|
||||
deprecated::api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
|
||||
|
||||
std::string ToString() const override;
|
||||
GraphDebugInfoPtr debug_info();
|
||||
|
|
|
@ -55,9 +55,9 @@ class FuncGraphTransaction;
|
|||
class FuncGraphManager;
|
||||
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
||||
|
||||
using AnfNodeIndexSet = api::AnfNodeIndexSet;
|
||||
using AnfNodeIndexSet = deprecated::api::AnfNodeIndexSet;
|
||||
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
|
||||
using NodeUsersMap = api::NodeUsersMap;
|
||||
using NodeUsersMap = deprecated::api::NodeUsersMap;
|
||||
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
|
||||
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
|
||||
|
||||
|
@ -277,7 +277,8 @@ class FuncGraphJTotalComputer final : public DepComputer {
|
|||
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
|
||||
};
|
||||
|
||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
|
||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>,
|
||||
public deprecated::api::FuncGraphManager {
|
||||
public:
|
||||
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
|
||||
~FuncGraphManager() {
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "mindapi/base/macros.h"
|
||||
#include "mindapi/base/type_traits.h"
|
||||
#include "mindapi/base/shared_ptr.h"
|
||||
|
||||
namespace mindspore {
|
||||
class Base;
|
||||
}
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief Base is the base class of many api classes, which provides basic interfaces.
|
||||
class MIND_API Base {
|
||||
public:
|
||||
/// \brief Create an instance from the given implementation object.
|
||||
///
|
||||
/// \param[in] impl The shared_ptr to the implementation object.
|
||||
explicit Base(const std::shared_ptr<mindspore::Base> &impl);
|
||||
|
||||
/// \brief Destructor of Base.
|
||||
virtual ~Base() = default;
|
||||
|
||||
/// \brief Get the id of this class.
|
||||
///
|
||||
/// \return The id of this class.
|
||||
static uint32_t ClassId();
|
||||
|
||||
/// \brief Get the shared_ptr to the underly implementation object.
|
||||
///
|
||||
/// \return The shared_ptr to the underly implementation object.
|
||||
const std::shared_ptr<mindspore::Base> &impl() const { return impl_; }
|
||||
|
||||
/// \brief Get the string representation of this object.
|
||||
///
|
||||
/// \return The string representation.
|
||||
std::string ToString() const;
|
||||
|
||||
/// \brief Check whether this object is an instance of the given class.
|
||||
///
|
||||
/// \return True if this object is an instance of the given class, false otherwise.
|
||||
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
|
||||
inline bool isa() const {
|
||||
return IsFromClassId(T::ClassId());
|
||||
}
|
||||
|
||||
/// \brief Cast this object to a pointer with the given pointer class.
|
||||
///
|
||||
/// \return A non-null pointer if cast success, nullptr otherwise.
|
||||
template <typename T, typename U = typename std::enable_if_t<is_wrapper_ptr<T>::value, typename T::element_type>>
|
||||
inline T cast() {
|
||||
if (isa<U>()) {
|
||||
return MakeShared<U>(impl_);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool IsFromClassId(uint32_t class_id) const;
|
||||
const std::shared_ptr<mindspore::Base> impl_;
|
||||
};
|
||||
|
||||
#define MIND_API_BASE_MEMBER(current_class) \
|
||||
explicit current_class(const std::shared_ptr<mindspore::Base> &impl); \
|
||||
~current_class() override = default; \
|
||||
static uint32_t ClassId()
|
||||
|
||||
using BasePtr = SharedPtr<Base>;
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include "mindapi/base/macros.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
enum class LogLevel : uint8_t { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION };
|
||||
|
||||
class LogWriterImpl;
|
||||
|
||||
/// \brief LogStream represents a stream to write log messages.
|
||||
/// This class is not expected for directly use, use MS_LOG instead.
|
||||
class LogStream {
|
||||
public:
|
||||
/// \brief Write log message to this LogStream.
|
||||
///
|
||||
/// \param[in] value The object to be written.
|
||||
template <typename T>
|
||||
LogStream &operator<<(T &&value) noexcept {
|
||||
(void)stream_.operator<<(std::forward<T>(value));
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class LogWriterImpl;
|
||||
std::stringstream stream_;
|
||||
};
|
||||
|
||||
/// \brief LogWriter defines interface for log message output.
|
||||
/// This class is not expected for directly use, use MS_LOG instead.
|
||||
class MIND_API LogWriter {
|
||||
public:
|
||||
/// \brief Create a LogWriter with the given log level, file name, line number and function name.
|
||||
///
|
||||
/// \param[in] level The log level.
|
||||
/// \param[in] file The file name.
|
||||
/// \param[in] line The line number.
|
||||
/// \param[in] func The function name.
|
||||
LogWriter(LogLevel level, const char *file, int line, const char *func);
|
||||
|
||||
/// \brief Destructor for LogWriter.
|
||||
~LogWriter();
|
||||
|
||||
/// \brief Output log message from the input log stream.
|
||||
///
|
||||
/// \param[in] stream The input log stream.
|
||||
void operator<(const LogStream &stream) const noexcept;
|
||||
|
||||
/// \brief Output log message from the input log stream and then throw exception.
|
||||
///
|
||||
/// \param[in] stream The input log stream.
|
||||
void operator^(const LogStream &stream) const __attribute__((noreturn));
|
||||
|
||||
/// \brief Check whether the given log level is enabled or not.
|
||||
///
|
||||
/// \return True if the log level is enabled, false otherwise.
|
||||
static bool IsEnabled(LogLevel level);
|
||||
|
||||
private:
|
||||
std::unique_ptr<LogWriterImpl> impl_;
|
||||
};
|
||||
|
||||
#define MIND_LOG_STREAM mindspore::api::LogStream()
|
||||
#define MIND_LOG_WRITER mindspore::api::LogWriter
|
||||
#define MIND_LOG_LEVEL(L) mindspore::api::LogLevel::L
|
||||
|
||||
#define MIND_LOG_THROW(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) ^ MIND_LOG_STREAM
|
||||
#define MIND_LOG_WRITE(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) < MIND_LOG_STREAM
|
||||
#define MIND_LOG_IF(L) \
|
||||
if (MIND_LOG_WRITER::IsEnabled(MIND_LOG_LEVEL(L))) MIND_LOG_WRITE(L)
|
||||
|
||||
#define MIND_LOG_DEBUG MIND_LOG_IF(DEBUG)
|
||||
#define MIND_LOG_INFO MIND_LOG_IF(INFO)
|
||||
#define MIND_LOG_WARNING MIND_LOG_IF(WARNING)
|
||||
#define MIND_LOG_ERROR MIND_LOG_IF(ERROR)
|
||||
#define MIND_LOG_EXCEPTION MIND_LOG_THROW(EXCEPTION)
|
||||
#define MIND_LOG(level) MIND_LOG_##level
|
||||
|
||||
#if !defined(MIND_LOG_NO_MS_LOG) && !defined(MS_LOG)
|
||||
#define MS_LOG(level) MIND_LOG_##level
|
||||
#endif
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
|
||||
|
||||
#if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__))
|
||||
#ifdef BUILDING_DLL
|
||||
#define MIND_API __declspec(dllexport)
|
||||
#else
|
||||
#define MIND_API __declspec(dllimport)
|
||||
#endif
|
||||
#else
|
||||
#define MIND_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
using ShapeVector = std::vector<int64_t>;
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
|
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <ostream>
|
||||
#include <functional>
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief SharedPtr wraps a std::shared_ptr and provides wrapper functions according the underlying implementation.
|
||||
template <typename T>
|
||||
class SharedPtr {
|
||||
public:
|
||||
using element_type = T;
|
||||
constexpr SharedPtr() noexcept = default;
|
||||
constexpr SharedPtr(std::nullptr_t) noexcept : SharedPtr() {} // NOLINT
|
||||
template <typename U>
|
||||
explicit SharedPtr(std::shared_ptr<U> &&ptr) : ptr_(std::move(ptr)) {}
|
||||
template <typename U>
|
||||
SharedPtr(const SharedPtr<U> &other) : ptr_(other.ptr_) {}
|
||||
template <typename U>
|
||||
SharedPtr(SharedPtr<U> &&other) : ptr_(std::move(other.ptr_)) {}
|
||||
template <typename U>
|
||||
SharedPtr &operator=(const SharedPtr<U> &other) {
|
||||
ptr_ = other.ptr_;
|
||||
return *this;
|
||||
}
|
||||
template <typename U>
|
||||
SharedPtr &operator=(SharedPtr<U> &&other) {
|
||||
ptr_ = std::move(other.ptr_);
|
||||
return *this;
|
||||
}
|
||||
~SharedPtr() = default;
|
||||
|
||||
std::uintptr_t addr() const { return (ptr_ == nullptr) ? 0 : reinterpret_cast<std::uintptr_t>(ptr_->impl().get()); }
|
||||
element_type &operator*() const noexcept { return *ptr_; }
|
||||
element_type *operator->() const noexcept { return ptr_.get(); }
|
||||
element_type *get() const noexcept { return ptr_.get(); }
|
||||
explicit operator bool() const { return addr() != 0; }
|
||||
|
||||
private:
|
||||
template <typename U>
|
||||
friend class SharedPtr;
|
||||
std::shared_ptr<element_type> ptr_;
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator==(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() == b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator==(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator==(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
return a.addr() == 0;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator!=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() != b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator!=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() != 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator!=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
return a.addr() != 0;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator<(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() < b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator<(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() < 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator<(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
// 'nullptr < ptr' is false only when ptr is nullptr.
|
||||
return a.addr() != 0;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator>(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() > b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator>(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() > 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator>(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
// 'nullptr > ptr' is always false.
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator<=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() <= b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator<=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() <= 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator<=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
// 'nullptr <= ptr' is always true.
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool operator>=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
|
||||
return a.addr() >= b.addr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator>=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
|
||||
return a.addr() >= 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool operator>=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
|
||||
// 'nullptr >= ptr' is true only when ptr is nullptr.
|
||||
return a.addr() == 0;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
inline std::basic_ostream<U, V> &operator<<(std::basic_ostream<U, V> &os, const SharedPtr<T> &a) {
|
||||
return (os << reinterpret_cast<void *>(a.addr()));
|
||||
}
|
||||
|
||||
/// \brief Constructs an object of type T and wraps it in a SharedPtr.
|
||||
///
|
||||
/// \param[in] args The parameter list for the constructor of T.
|
||||
template <typename T, typename... Args>
|
||||
inline SharedPtr<T> MakeShared(Args &&... args) {
|
||||
auto ptr = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
return SharedPtr<T>(std::move(ptr));
|
||||
}
|
||||
} // namespace mindspore::api
|
||||
|
||||
namespace std {
|
||||
template <typename T>
|
||||
struct hash<mindspore::api::SharedPtr<T>> {
|
||||
size_t operator()(const mindspore::api::SharedPtr<T> &ptr) const noexcept { return static_cast<size_t>(ptr.addr()); }
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
|
||||
|
||||
namespace mindspore {
|
||||
/// \brief TypeId defines data type identifiers.
|
||||
enum TypeId : int {
|
||||
kTypeUnknown = 0,
|
||||
//
|
||||
// Meta types.
|
||||
//
|
||||
kMetaTypeBegin = kTypeUnknown,
|
||||
kMetaTypeType, // Type
|
||||
kMetaTypeAnything,
|
||||
kMetaTypeObject,
|
||||
kMetaTypeTypeType, // TypeType
|
||||
kMetaTypeProblem,
|
||||
kMetaTypeExternal,
|
||||
kMetaTypeNone,
|
||||
kMetaTypeNull,
|
||||
kMetaTypeEllipsis,
|
||||
kMetaTypeEnd,
|
||||
//
|
||||
// Object types
|
||||
//
|
||||
kObjectTypeBegin = kMetaTypeEnd,
|
||||
kObjectTypeNumber,
|
||||
kObjectTypeString,
|
||||
kObjectTypeList,
|
||||
kObjectTypeTuple,
|
||||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeRowTensorType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
kObjectTypeFunction,
|
||||
kObjectTypeJTagged,
|
||||
kObjectTypeSymbolicKeyType,
|
||||
kObjectTypeEnvType,
|
||||
kObjectTypeRefKey,
|
||||
kObjectTypeRef,
|
||||
kObjectTypeEnd,
|
||||
//
|
||||
// Number Types
|
||||
//
|
||||
kNumberTypeBegin = kObjectTypeEnd,
|
||||
kNumberTypeBool,
|
||||
kNumberTypeInt,
|
||||
kNumberTypeInt8,
|
||||
kNumberTypeInt16,
|
||||
kNumberTypeInt32,
|
||||
kNumberTypeInt64,
|
||||
kNumberTypeUInt,
|
||||
kNumberTypeUInt8,
|
||||
kNumberTypeUInt16,
|
||||
kNumberTypeUInt32,
|
||||
kNumberTypeUInt64,
|
||||
kNumberTypeFloat,
|
||||
kNumberTypeFloat16,
|
||||
kNumberTypeFloat32,
|
||||
kNumberTypeFloat64,
|
||||
kNumberTypeComplex,
|
||||
kNumberTypeComplex64,
|
||||
kNumberTypeComplex128,
|
||||
kNumberTypeInt4,
|
||||
kNumberTypeGLUInt,
|
||||
kNumberTypeEnd,
|
||||
//
|
||||
// Monad Types
|
||||
//
|
||||
kMonadTypeBegin = kNumberTypeEnd,
|
||||
kObjectTypeMonad,
|
||||
kObjectTypeUMonad,
|
||||
kObjectTypeIOMonad,
|
||||
kMonadTypeEnd,
|
||||
//
|
||||
// Sparse Types
|
||||
//
|
||||
// Sparse types is placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
kSparseTypeBegin = kMonadTypeEnd,
|
||||
kObjectTypeCSRTensorType,
|
||||
kSparseTypeEnd
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include "mindapi/base/shared_ptr.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
template <typename T>
|
||||
struct is_wrapper_ptr : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_wrapper_ptr<SharedPtr<T>> : public std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_shared_ptr : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_vector : public std::false_type {};
|
||||
template <typename T, typename A>
|
||||
struct is_vector<std::vector<T, A>> : public std::true_type {};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
||||
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/shape.h"
|
||||
#include "mindapi/ir/type.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief AbstractBase defines base interfaces for abstract of an anf node.
|
||||
class MIND_API AbstractBase : public Base {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AbstractBase);
|
||||
|
||||
/// \brief Clone an abstract from this abstract.
|
||||
///
|
||||
/// \return A pointer to the cloned abstract.
|
||||
AbstractBasePtr Clone() const;
|
||||
|
||||
/// \brief Get the abstract type.
|
||||
///
|
||||
/// \return A pointer to the Type.
|
||||
TypePtr type() const;
|
||||
|
||||
/// \brief Get the abstract value.
|
||||
///
|
||||
/// \return A pointer to the Value.
|
||||
ValuePtr value() const;
|
||||
|
||||
/// \brief Set the type for this abstract.
|
||||
///
|
||||
/// \param[in] type The type to be set.
|
||||
void set_type(const TypePtr &type);
|
||||
|
||||
/// \brief Set the value for this abstract.
|
||||
///
|
||||
/// \param[in] value The value to be set.
|
||||
void set_value(const ValuePtr &value);
|
||||
};
|
||||
|
||||
/// \brief AbstractTensor describes a tensor's type, shape and value.
|
||||
class MIND_API AbstractTensor : public AbstractBase {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AbstractTensor);
|
||||
|
||||
/// \brief Create AbstractTensor from the given type and shape.
|
||||
///
|
||||
/// \param[in] type The data type id of the tensor.
|
||||
/// \param[in] shape The shape of the tensor.
|
||||
AbstractTensor(TypeId type, const ShapeVector &shape);
|
||||
|
||||
/// \brief Get the element abstract.
|
||||
///
|
||||
/// \return A pointer to the element abstract.
|
||||
AbstractBasePtr element() const;
|
||||
|
||||
/// \brief Get the shape of the abstract.
|
||||
///
|
||||
/// \return A pointer to the shape.
|
||||
ShapePtr shape() const;
|
||||
};
|
||||
|
||||
using AbstractTensorPtr = SharedPtr<AbstractTensor>;
|
||||
|
||||
/// \brief AbstractSequence describes the abstract for a tuple or list.
|
||||
class MIND_API AbstractSequence : public AbstractBase {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AbstractSequence);
|
||||
|
||||
/// \brief Get element abstracts.
|
||||
///
|
||||
/// \return A vector of element abstracts.
|
||||
AbstractBasePtrList elements() const;
|
||||
};
|
||||
|
||||
using AbstractSequencePtr = SharedPtr<AbstractSequence>;
|
||||
|
||||
/// \brief AbstractTuple describes the abstract for a tuple.
|
||||
class MIND_API AbstractTuple : public AbstractSequence {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AbstractTuple);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
|
@ -0,0 +1,232 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_ANF_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/abstract.h"
|
||||
#include "mindapi/ir/primitive.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief AnfNode is the basic class of the IR graph node.
|
||||
class MIND_API AnfNode : public Base {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AnfNode);
|
||||
|
||||
/// \brief Obtain detailed information about scope namespace.
|
||||
///
|
||||
/// \return Detailed information about scope namespace.
|
||||
std::string fullname_with_scope() const;
|
||||
|
||||
/// \brief Obtain the inferred abstract value of this AnfNode.
|
||||
///
|
||||
/// \return The inferred abstract value.
|
||||
AbstractBasePtr abstract() const;
|
||||
|
||||
/// \brief Set the abstract value of this AnfNode.
|
||||
///
|
||||
/// \param[in] abs New abstract value.
|
||||
void set_abstract(const AbstractBasePtr &abs);
|
||||
};
|
||||
|
||||
/// \brief CNode represents a compute node with a set of input nodes.
|
||||
class MIND_API CNode : public AnfNode {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CNode);
|
||||
|
||||
/// \brief Get the number of inputs.
|
||||
///
|
||||
/// \return The number of inputs in this CNode.
|
||||
size_t size() const;
|
||||
|
||||
/// \brief Get the input node of the given index.
|
||||
///
|
||||
/// \param[in] i The given index.
|
||||
///
|
||||
/// \return The input node of the given index.
|
||||
AnfNodePtr input(size_t i) const;
|
||||
|
||||
/// \brief Get the input nodes.
|
||||
///
|
||||
/// \return The input nodes of this CNode.
|
||||
std::vector<AnfNodePtr> inputs() const;
|
||||
|
||||
/// \brief Set the input nodes for this CNode.
|
||||
///
|
||||
/// \param[in] inputs Input nodes.
|
||||
void set_inputs(const std::vector<AnfNodePtr> &inputs);
|
||||
|
||||
/// \brief Add an input node to this CNode.
|
||||
///
|
||||
/// \param[in] input the input node to be added.
|
||||
void add_input(const AnfNodePtr &input);
|
||||
|
||||
/// \brief Set fullname_with_scope for this CNode.
|
||||
///
|
||||
/// \param[in] full_name The fullname_with_scope.
|
||||
void set_fullname_with_scope(const std::string &full_name);
|
||||
|
||||
/// \brief Add a new attribute to this CNode.
|
||||
///
|
||||
/// \param[in] name The name of the new attribute.
|
||||
/// \param[in] attr The value of the new attribute.
|
||||
void AddAttr(const std::string &name, const ValuePtr &attr);
|
||||
|
||||
/// \brief Erase the attribute with the given name.
|
||||
///
|
||||
/// \param[in] name The name of attribute.
|
||||
void EraseAttr(const std::string &name);
|
||||
|
||||
/// \brief Get the attribute with the given name.
|
||||
///
|
||||
/// \param[in] name The name of attribute.
|
||||
/// \return Attribute.
|
||||
ValuePtr GetAttr(const std::string &name) const;
|
||||
};
|
||||
|
||||
using CNodePtr = SharedPtr<CNode>;
|
||||
|
||||
/// \brief Parameter represents the parameter inputs of a function.
|
||||
class MIND_API Parameter : public AnfNode {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Parameter);
|
||||
|
||||
/// \brief Get the name of this Parameter.
|
||||
///
|
||||
/// \return The name.
|
||||
std::string name() const;
|
||||
|
||||
/// \brief Set the name of this Parameter.
|
||||
///
|
||||
/// \param[in] The name.
|
||||
void set_name(const std::string &name);
|
||||
|
||||
/// \brief Check if there is a default parameter.
|
||||
///
|
||||
/// \return True if this Parameter has a default parameter, otherwise false.
|
||||
bool has_default() const;
|
||||
|
||||
/// \brief Set the default parameter.
|
||||
///
|
||||
/// \param[in] param The default parameter.
|
||||
void set_default_param(const ValuePtr ¶m);
|
||||
|
||||
/// \brief Get the default parameter.
|
||||
///
|
||||
/// \return The default parameter.
|
||||
ValuePtr default_param() const;
|
||||
};
|
||||
|
||||
using ParameterPtr = SharedPtr<Parameter>;
|
||||
|
||||
/// \brief ValueNode is a graph node that hold a value.
|
||||
class MIND_API ValueNode : public AnfNode {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ValueNode);
|
||||
|
||||
/// \brief Create ValueNode with the given value.
|
||||
///
|
||||
/// \param[in] value The value of this ValueNode.
|
||||
explicit ValueNode(const ValuePtr &value);
|
||||
|
||||
/// \brief Get the value of this ValueNode.
|
||||
///
|
||||
/// \return The value.
|
||||
ValuePtr value() const;
|
||||
};
|
||||
|
||||
using ValueNodePtr = SharedPtr<ValueNode>;
|
||||
|
||||
// === ANF utility functions === //
|
||||
|
||||
/// \brief Create a ValueNode with the given value.
|
||||
///
|
||||
/// \param[in] value The given value.
|
||||
///
|
||||
/// \return The created ValueNode.
|
||||
inline ValueNodePtr NewValueNode(const ValuePtr &value) { return MakeShared<ValueNode>(value); }
|
||||
|
||||
/// \brief Create a ValueNode with the given primitive type value.
|
||||
///
|
||||
/// \param[in] value The given primitive type value.
|
||||
///
|
||||
/// \return The created ValueNode.
|
||||
template <typename T>
|
||||
inline ValueNodePtr NewValueNode(T value) {
|
||||
return NewValueNode(MakeValue(value));
|
||||
}
|
||||
|
||||
/// \brief Get the value from a node if it is a ValueNode.
|
||||
///
|
||||
/// \param[in] node The node which may hold a value.
|
||||
///
|
||||
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set.
|
||||
inline ValuePtr GetValueNode(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return value_node->value();
|
||||
}
|
||||
|
||||
/// \brief Get the value with the given type from a node if it is a ValueNode.
|
||||
///
|
||||
/// \param[in] node The node which may hold a value.
|
||||
///
|
||||
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch.
|
||||
template <typename T, typename = typename std::enable_if_t<
|
||||
is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
|
||||
inline T GetValueNode(const AnfNodePtr &node) {
|
||||
auto value = GetValueNode(node);
|
||||
if (value == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return value->cast<T>();
|
||||
}
|
||||
|
||||
/// \brief Check whether the given node is a cnode with the given Primitive as the first input.
|
||||
///
|
||||
/// \param[in] node The given node to be checked.
|
||||
/// \param[in] prim The Primitive value, nullptr means match any Primitive.
|
||||
///
|
||||
/// \return True if the node is cnode and the first input is the given Primitive, false otherwise.
|
||||
MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);
|
||||
|
||||
/// \brief Check whether the given node is a ValueNode with the given Primitive.
|
||||
///
|
||||
/// \param[in] node The given node to be checked.
|
||||
/// \param[in] prim The Primitive value.
|
||||
///
|
||||
/// \return True if the given node is a ValueNode with the given Primitive, false otherwise.
|
||||
MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);
|
||||
|
||||
/// \brief Check if a node is a data node.
|
||||
/// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes.
|
||||
///
|
||||
/// \param[in] node The node to be checked.
|
||||
///
|
||||
/// \return True if the node is a data node, false otherwise.
|
||||
MIND_API bool IsDataNode(const AnfNodePtr &node);
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_ANF_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
||||
|
||||
#include <vector>
|
||||
#include "mindapi/base/shared_ptr.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class AnfNode;
|
||||
using AnfNodePtr = SharedPtr<AnfNode>;
|
||||
using AnfNodePtrList = std::vector<AnfNodePtr>;
|
||||
|
||||
class Value;
|
||||
using ValuePtr = SharedPtr<Value>;
|
||||
|
||||
class Primitive;
|
||||
using PrimitivePtr = SharedPtr<Primitive>;
|
||||
|
||||
class Type;
|
||||
using TypePtr = SharedPtr<Type>;
|
||||
|
||||
class AbstractBase;
|
||||
using AbstractBasePtr = SharedPtr<AbstractBase>;
|
||||
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
|
||||
|
||||
class Shape;
|
||||
using ShapePtr = SharedPtr<Shape>;
|
||||
|
||||
class FuncGraph;
|
||||
using FuncGraphPtr = SharedPtr<FuncGraph>;
|
||||
|
||||
class FuncGraphManager;
|
||||
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
|
|
@ -0,0 +1,193 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/anf.h"
|
||||
#include "mindapi/ir/primitive.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
#include "mindapi/ir/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
class FuncGraphManager;
|
||||
}
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief FuncGraph defines interface for a function graph.
|
||||
class MIND_API FuncGraph : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(FuncGraph);
|
||||
|
||||
/// \brief Get the input parameters.
|
||||
///
|
||||
/// \return Input parameters of this graph.
|
||||
std::vector<AnfNodePtr> get_inputs() const;
|
||||
|
||||
/// \brief Get all parameters.
|
||||
///
|
||||
/// \return All parameters of this graph.
|
||||
std::vector<AnfNodePtr> parameters() const;
|
||||
|
||||
/// \brief Adds a parameter to this graph.
|
||||
///
|
||||
/// \param[in] p The parameter to be added.
|
||||
void add_parameter(const ParameterPtr &p);
|
||||
|
||||
/// \brief Adds a new parameter to this graph.
|
||||
///
|
||||
/// \return The new added parameter.
|
||||
ParameterPtr add_parameter();
|
||||
|
||||
/// \brief Get the output node.
|
||||
///
|
||||
/// \return The output node, nullptr if output not set.
|
||||
AnfNodePtr output() const;
|
||||
|
||||
/// \brief Get the return CNode.
|
||||
///
|
||||
/// \return The return CNode, nullptr if no return node.
|
||||
CNodePtr get_return() const;
|
||||
|
||||
/// \brief Set the output node.
|
||||
///
|
||||
/// \param[in] value The output node to be set.
|
||||
/// \param[in] force_new_ret If true, a new return node is always created.
|
||||
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
|
||||
|
||||
/// \brief Set the return node.
|
||||
///
|
||||
/// \param[in] cnode The return CNode to be set.
|
||||
void set_return(const CNodePtr &cnode);
|
||||
|
||||
/// \brief Creates a new CNode in this graph.
|
||||
///
|
||||
/// \param[in] inputs The input nodes of the new CNode.
|
||||
///
|
||||
/// \return The created CNode.
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
|
||||
|
||||
/// \brief Creates a new primitive CNode in this graph.
|
||||
///
|
||||
/// \param[in] primitive The primitive of the new CNode.
|
||||
/// \param[in] prim_inputs The argument inputs of the primitive CNode.
|
||||
///
|
||||
/// \return The created primitive CNode.
|
||||
CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
|
||||
|
||||
/// \brief Get all nodes in this graph.
|
||||
///
|
||||
/// \return All nodes in this graph.
|
||||
std::vector<AnfNodePtr> nodes() const;
|
||||
|
||||
/// \brief Check whether an attribute is set for this graph.
|
||||
///
|
||||
/// \param[in] key The attribute key (name).
|
||||
///
|
||||
/// \return True if the attribute with the given key is set, false otherwise.
|
||||
bool has_attr(const std::string &key) const;
|
||||
|
||||
/// \brief Get an attribute value by its key.
|
||||
///
|
||||
/// \param[in] key The attribute key (name).
|
||||
///
|
||||
/// \return The attribute value for the given key, nullptr if attribute not found.
|
||||
ValuePtr get_attr(const std::string &key) const;
|
||||
|
||||
/// \brief Set an attribute value.
|
||||
///
|
||||
/// \param[in] key The attribute key (name).
|
||||
/// \param[in] value The attribute value.
|
||||
void set_attr(const std::string &key, const ValuePtr &value);
|
||||
|
||||
/// \brief Get the manager for this graph.
|
||||
///
|
||||
/// \return The manager of this graph, nullptr if not set.
|
||||
FuncGraphManagerPtr manager() const;
|
||||
|
||||
/// \brief Creates an empty function graph.
|
||||
///
|
||||
/// \return The created function graph.
|
||||
static FuncGraphPtr Create();
|
||||
|
||||
/// \brief Topological sort a graph from the given end node.
|
||||
///
|
||||
/// \param[in] node The end node of the graph to be sorted.
|
||||
///
|
||||
/// \return The sorted nodes.
|
||||
static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
|
||||
};
|
||||
|
||||
/// \brief FuncGraphManager defines interface for function graph management.
|
||||
class MIND_API FuncGraphManager {
|
||||
public:
|
||||
/// \brief Create FuncGraphManager with the given implementor object.
|
||||
///
|
||||
/// \param[in] impl The pointer to the implementor object.
|
||||
explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl);
|
||||
|
||||
/// \brief Get the shared_ptr to the underly implementation object.
|
||||
///
|
||||
/// \return The shared_ptr to the underly implementation object.
|
||||
const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; }
|
||||
|
||||
/// \brief Replace an old node with a new node, related edges are all updated.
|
||||
///
|
||||
/// \param[in] old_node The old node to be replaced.
|
||||
/// \param[in] new_node The new node that replace the old one.
|
||||
///
|
||||
/// \return True if the node is successfully replaced, false otherwise.
|
||||
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||
|
||||
/// \brief Change an existed edge by replace its input node.
|
||||
///
|
||||
/// \param[in] node The output node of the edge.
|
||||
/// \param[in] index The input index in output node.
|
||||
/// \param[in] value The new input node of the edge.
|
||||
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
|
||||
|
||||
/// \brief Adds a new edge between the given two nodes.
|
||||
///
|
||||
/// \param[in] node The output node of the edge.
|
||||
/// \param[in] value The input node of the edge.
|
||||
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
|
||||
|
||||
/// \brief Find users of the given node.
|
||||
///
|
||||
/// \param[in] node The node.
|
||||
///
|
||||
/// \return Users of the given node, empty if user not found.
|
||||
std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const;
|
||||
|
||||
/// \brief Manage the give function graph.
|
||||
///
|
||||
/// \param[in] func_graph The function graph to be managed.
|
||||
/// \param[in] manage If true, the created manager will be set in the graph.
|
||||
///
|
||||
/// \return The manager that manages the given function graph.
|
||||
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
|
||||
|
||||
private:
|
||||
const std::shared_ptr<mindspore::FuncGraphManager> impl_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief Primitive defines a primitive operator.
|
||||
class MIND_API Primitive : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Primitive);
|
||||
|
||||
/// \brief Create primitive with the given name.
|
||||
///
|
||||
/// \param[in] name The primitive name.
|
||||
explicit Primitive(const std::string &name);
|
||||
|
||||
/// \brief Get name of the primitive.
|
||||
///
|
||||
/// \return The name of primitive.
|
||||
const std::string &name() const;
|
||||
|
||||
/// \brief Add attribute to primitive.
|
||||
///
|
||||
/// \param[in] name The attribute name.
|
||||
/// \param[in] attr The attribute value.
|
||||
/// \return The primitive to which attribute has been added.
|
||||
Primitive &AddAttr(const std::string &name, const ValuePtr &attr);
|
||||
|
||||
/// \brief Add attributes by using a map, all elements of the map will be added to this primitive.
|
||||
///
|
||||
/// \param[in] attrs The attribute map needs to be added in the primitive attribute.
|
||||
/// \return The primitive to which attribute has been added.
|
||||
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs);
|
||||
|
||||
/// \brief Erase attribute to the primitive attribute map.
|
||||
///
|
||||
/// \param[in] name The attribute name.
|
||||
void EraseAttr(const std::string &name);
|
||||
|
||||
/// \brief Get attribute value by name.
|
||||
///
|
||||
/// \param[in] name the attribute name.
|
||||
/// \return The value of the attribute, null if attribute name not found.
|
||||
ValuePtr GetAttr(const std::string &name) const;
|
||||
|
||||
/// \brief Check If Primitive has an attribute with then given name.
|
||||
///
|
||||
/// \param[in] name The attribute name.
|
||||
/// \return True if there is an attribute with the given name, otherwise false.
|
||||
bool HasAttr(const std::string &name) const;
|
||||
|
||||
/// \brief Get all attributes of this primitive as a map.
|
||||
///
|
||||
/// \return The attribute map of this primitive.
|
||||
std::unordered_map<std::string, ValuePtr> attrs() const;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
|
||||
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/base/shape_vector.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief Shape defines dimensions of a tensor.
|
||||
class MIND_API Shape : public Base {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Shape);
|
||||
|
||||
/// \brief Get the shape dimensions.
|
||||
///
|
||||
/// \return The shape dimensions.
|
||||
const ShapeVector &shape() const;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/base/shape_vector.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief Tensor represents a multi-dimensional array of elements.
|
||||
class MIND_API Tensor : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Tensor);
|
||||
|
||||
/// \brief Create a lazy allocated tensor.
|
||||
///
|
||||
/// \param[in] data_type [TypeId] Data type of the tensor.
|
||||
/// \param[in] shape The shape represented by ShapeVector of the tensor.
|
||||
Tensor(TypeId data_type, const ShapeVector &shape);
|
||||
|
||||
/// \brief Create a tensor with input data buffer.
|
||||
///
|
||||
/// \param[in] data_type [TypeId] Data type of the tensor.
|
||||
/// \param[in] shape The shape represented by ShapeVector of the tensor.
|
||||
/// \param[in] data The input data to be copied into tensor.
|
||||
/// \param[in] data_len The length of data in bytes.
|
||||
Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
|
||||
|
||||
/// \brief Get the shape of the tensor.
|
||||
/// The shape of a tensor is stored in a vector<int64_t>. Each element of the
|
||||
/// vector represents the size of a dimension of the tensor. The order of each
|
||||
/// element in the vector is the same as the the dimension's order it represents.
|
||||
///
|
||||
/// \return A vector<int64_t> which represents the shape of the tensor.
|
||||
const ShapeVector &shape() const;
|
||||
|
||||
/// \brief Set the shape of tensor.
|
||||
///
|
||||
/// \param[in] shape The shape to be set.
|
||||
void set_shape(const ShapeVector &shape);
|
||||
|
||||
/// \brief Get the data type of the tensor.
|
||||
///
|
||||
/// \return The data type of the tensor.
|
||||
TypeId data_type() const;
|
||||
|
||||
/// \brief Set the data type of the tensor.
|
||||
///
|
||||
/// \param[in] data_type The data type to be set.
|
||||
void set_data_type(const TypeId data_type);
|
||||
|
||||
/// \brief Get The pointer to the underlying memory block for data storage.
|
||||
///
|
||||
/// \return The pointer to the underlying data.
|
||||
const void *data() const;
|
||||
|
||||
/// \brief Get The pointer to the underlying memory block for data storage.
|
||||
///
|
||||
/// \return The pointer to the underlying data.
|
||||
void *data();
|
||||
|
||||
/// \brief Get tensor data size.
|
||||
///
|
||||
/// \return The total number of elements in the tensor.
|
||||
int DataSize() const;
|
||||
|
||||
/// \brief Get tensor data size in bytes.
|
||||
///
|
||||
/// \return The total number of bytes for the tensor data.
|
||||
std::size_t Size() const;
|
||||
};
|
||||
|
||||
using TensorPtr = SharedPtr<Tensor>;
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
|
||||
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
/// \brief Type defines the type of a value.
|
||||
class MIND_API Type : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Type);
|
||||
|
||||
/// \brief Get the id of the Type object.
|
||||
///
|
||||
/// \return The id of the Type object.
|
||||
TypeId type_id() const;
|
||||
|
||||
/// \brief Get the number type of the Type object.
|
||||
///
|
||||
/// \return The number type of this Type object, kTypeUnknown if this is not a number type.
|
||||
TypeId number_type() const;
|
||||
|
||||
/// \brief Get the Type according to a TypeId.
|
||||
///
|
||||
/// \param[in] id The id of the type.
|
||||
///
|
||||
/// \return The pointer to the Type.
|
||||
static TypePtr GetType(TypeId id);
|
||||
|
||||
/// \brief Get data size in bytes for the type according to a TypeId.
|
||||
///
|
||||
/// \param[in] id The id of the type.
|
||||
///
|
||||
/// \return The data size in bytes for the Type.
|
||||
static size_t GetSize(TypeId id);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
|
||||
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/base/shared_ptr.h"
|
||||
#include "mindapi/base/type_traits.h"
|
||||
#include "mindapi/ir/anf.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
#include "mindapi/ir/func_graph.h"
|
||||
|
||||
namespace mindspore::api::utils {
|
||||
/// \brief Check whether the given object is an instance of the given class.
|
||||
///
|
||||
/// \param[in] ptr The pointer to the given object.
|
||||
///
|
||||
/// \return True if the pointer is not null and the object is an instance of the given class, false otherwise.
|
||||
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
|
||||
bool isa(const BasePtr &ptr) {
|
||||
if (ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return ptr->isa<T>();
|
||||
}
|
||||
|
||||
/// \brief Cast the given object pointer to a pointer with the given class.
|
||||
///
|
||||
/// \param[in] ptr The pointer to the object to casted.
|
||||
///
|
||||
/// \return A non-null pointer if the input pointer is not null and cast success, nullptr otherwise.
|
||||
template <typename T, typename = typename std::enable_if_t<is_wrapper_ptr<T>::value, T>>
|
||||
T cast(const BasePtr &ptr) {
|
||||
if (ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return ptr->cast<T>();
|
||||
}
|
||||
|
||||
/// \brief Make a copy from the given function graph.
|
||||
///
|
||||
/// \param[in] func_graph The graph to be cloned.
|
||||
///
|
||||
/// \return The cloned graph.
|
||||
MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph);
|
||||
|
||||
/// \brief Get pad mode id from a value holds the pad mode name or id.
|
||||
///
|
||||
/// \param[in] value The value holds the pad mode name or id.
|
||||
/// \param[in] is_upper Indicates whether the name is uppercase or lowercase, default is false for lowercase.
|
||||
///
|
||||
/// \return The pad mode id.
|
||||
MIND_API int64_t GetPadMode(const ValuePtr &value, bool is_upper = false);
|
||||
} // namespace mindspore::api::utils
|
||||
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
|
|
@ -0,0 +1,270 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include "mindapi/base/base.h"
|
||||
#include "mindapi/ir/common.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
template <typename T>
|
||||
struct ImmTrait {};
|
||||
|
||||
#define MIND_API_IMM_TRAIT(typeimm, prototype) \
|
||||
template <> \
|
||||
struct ImmTrait<prototype> { \
|
||||
using type = SharedPtr<typeimm>; \
|
||||
}
|
||||
|
||||
/// \brief Value represents a value in expression.
|
||||
class MIND_API Value : public Base {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Value);
|
||||
|
||||
/// \brief Get the type of this Value.
|
||||
///
|
||||
/// \return The type.
|
||||
TypePtr type() const;
|
||||
|
||||
/// \brief Get the abstract of this Value.
|
||||
///
|
||||
/// \return Abstract of this Value.
|
||||
AbstractBasePtr ToAbstract() const;
|
||||
};
|
||||
|
||||
/// \brief ValueSequence represents a sequence of values.
|
||||
class MIND_API ValueSequence : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ValueSequence);
|
||||
|
||||
/// \brief Get the size of this ValueSequence.
|
||||
///
|
||||
/// \return The size as the number of elements.
|
||||
std::size_t size() const;
|
||||
|
||||
/// \brief Get the list of values in this ValueSequence.
|
||||
///
|
||||
/// \return The list of element values.
|
||||
std::vector<ValuePtr> value() const;
|
||||
};
|
||||
|
||||
using ValueSequencePtr = SharedPtr<ValueSequence>;
|
||||
|
||||
/// \brief ValueTuple represents a value tuple.
|
||||
class MIND_API ValueTuple : public ValueSequence {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ValueTuple);
|
||||
|
||||
/// \brief Constructor of ValueTuple.
|
||||
///
|
||||
/// \param[in] elements The elements of the tuple.
|
||||
explicit ValueTuple(const std::vector<ValuePtr> &elements);
|
||||
};
|
||||
|
||||
using ValueTuplePtr = SharedPtr<ValueTuple>;
|
||||
|
||||
/// \brief StringImm defines a Value whose type is string.
|
||||
class MIND_API StringImm : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(StringImm);
|
||||
|
||||
/// \brief Create StringImm with the given string.
|
||||
///
|
||||
/// \param[in] str The given string value.
|
||||
explicit StringImm(const std::string &str);
|
||||
|
||||
/// \brief Get the string value of this StringImm.
|
||||
///
|
||||
/// \return The string value of this StringImm.
|
||||
const std::string &value() const;
|
||||
};
|
||||
|
||||
using StringImmPtr = SharedPtr<StringImm>;
|
||||
|
||||
MIND_API_IMM_TRAIT(StringImm, std::string);
|
||||
|
||||
/// \beief Scalar defines interface for scalar data.
|
||||
class MIND_API Scalar : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Scalar);
|
||||
};
|
||||
|
||||
/// \beief BoolImm defines interface for bool data.
|
||||
class MIND_API BoolImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(BoolImm);
|
||||
|
||||
/// \brief Create BoolImm with the given bool value.
|
||||
///
|
||||
/// \param[in] b The given bool value.
|
||||
explicit BoolImm(bool b);
|
||||
|
||||
/// \brief Get the bool value of this BoolImm.
|
||||
///
|
||||
/// \return The bool value of this BoolImm.
|
||||
bool value() const;
|
||||
};
|
||||
|
||||
using BoolImmPtr = SharedPtr<BoolImm>;
|
||||
|
||||
MIND_API_IMM_TRAIT(BoolImm, bool);
|
||||
|
||||
/// \beief IntegerImm defines interface for integer data.
|
||||
class MIND_API IntegerImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(IntegerImm);
|
||||
};
|
||||
|
||||
/// \beief Int64Imm defines interface for int64 data.
|
||||
class MIND_API Int64Imm : public IntegerImm {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Int64Imm);
|
||||
|
||||
/// \brief Create Int64Imm with the given int64 value.
|
||||
///
|
||||
/// \param[in] value The given bool value.
|
||||
explicit Int64Imm(int64_t value);
|
||||
|
||||
/// \brief Get the int64 value of this Int64Imm.
|
||||
///
|
||||
/// \return The int64 value of this Int64Imm.
|
||||
int64_t value() const;
|
||||
};
|
||||
|
||||
using Int64ImmPtr = SharedPtr<Int64Imm>;
|
||||
|
||||
MIND_API_IMM_TRAIT(Int64Imm, int64_t);
|
||||
|
||||
/// \beief FloatImm defines interface for float data.
|
||||
class MIND_API FloatImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(FloatImm);
|
||||
};
|
||||
|
||||
/// \beief FP32Imm defines interface for float32 data.
|
||||
class MIND_API FP32Imm : public FloatImm {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(FP32Imm);
|
||||
|
||||
/// \brief Create FP32Imm with the given float value.
|
||||
///
|
||||
/// \param[in] value The given float value.
|
||||
explicit FP32Imm(float value);
|
||||
|
||||
/// \brief Get the float value of this FP32Imm.
|
||||
///
|
||||
/// \return The float value of this FP32Imm.
|
||||
float value() const;
|
||||
};
|
||||
|
||||
using FP32ImmPtr = SharedPtr<FP32Imm>;
|
||||
|
||||
MIND_API_IMM_TRAIT(FP32Imm, float);
|
||||
|
||||
// === Utility functions for Value === //
|
||||
|
||||
/// \brief brief Create a Value object from a primitive type value.
|
||||
///
|
||||
/// \param[in] v The primitive type value.
|
||||
///
|
||||
/// \return The created Value object with the given primitive type value.
|
||||
template <typename T, typename U = typename ImmTrait<T>::type::element_type>
|
||||
inline ValuePtr MakeValue(T v) {
|
||||
return MakeShared<U>(v);
|
||||
}
|
||||
|
||||
/// \brief brief Create a StringImm Value object from a C string.
|
||||
///
|
||||
/// \param[in] s The C string.
|
||||
///
|
||||
/// \return The created StringImm Value object.
|
||||
inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); }
|
||||
|
||||
/// \brief brief Create a Int64Imm Value object from a int value.
|
||||
///
|
||||
/// \param[in] i The int value.
|
||||
///
|
||||
/// \return The created Int64Imm Value object.
|
||||
inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); }
|
||||
|
||||
/// \brief brief Create a ValueSequence object from a vector of values.
|
||||
///
|
||||
/// \param[in] values The vector of values.
|
||||
///
|
||||
/// \return The created ValueSequence object.
|
||||
inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); }
|
||||
|
||||
/// \brief Create a ValueSequence object from a vector of primitive type values.
|
||||
///
|
||||
/// \param[in] values The vector of primitive values.
|
||||
///
|
||||
/// \return The created ValueSequence object.
|
||||
template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
|
||||
inline ValuePtr MakeValue(const T &values) {
|
||||
std::vector<ValuePtr> value_vector;
|
||||
value_vector.reserve(values.size());
|
||||
for (auto &value : values) {
|
||||
value_vector.emplace_back(MakeValue(value));
|
||||
}
|
||||
return MakeShared<ValueTuple>(value_vector);
|
||||
}
|
||||
|
||||
/// \brief brief Get primitive type value from a Value object.
|
||||
///
|
||||
/// \param[in] value The pointer to the Value object.
|
||||
///
|
||||
/// \return The primitive type value of the Value object.
|
||||
template <typename T, typename U = typename ImmTrait<T>::type>
|
||||
inline T GetValue(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
return T();
|
||||
}
|
||||
U imm = value->cast<U>();
|
||||
if (imm == nullptr) {
|
||||
return T();
|
||||
}
|
||||
return imm->value();
|
||||
}
|
||||
|
||||
/// \brief brief Get primitive element values from a ValueSequeue object.
|
||||
///
|
||||
/// \param[in] value The pointer to the ValueSequeue object.
|
||||
///
|
||||
/// \return The primitive type values as a vector.
|
||||
template <typename T, typename S = typename std::decay_t<T>,
|
||||
typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
|
||||
inline std::vector<U> GetValue(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
return {};
|
||||
}
|
||||
auto seq = value->cast<ValueSequencePtr>();
|
||||
if (seq == nullptr) {
|
||||
return {};
|
||||
}
|
||||
auto elements = seq->value();
|
||||
std::vector<U> result;
|
||||
result.reserve(elements.size());
|
||||
for (auto &e : elements) {
|
||||
result.emplace_back(GetValue<U>(e));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/abstract.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using TypeImpl = mindspore::Type;
|
||||
using ValueImpl = mindspore::Value;
|
||||
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractBase, AbstractBaseImpl, Base);
|
||||
|
||||
AbstractBasePtr AbstractBase::Clone() const {
|
||||
auto abs = ToRef<AbstractBaseImpl>(impl_).Clone();
|
||||
return ToWrapper<AbstractBase>(abs);
|
||||
}
|
||||
|
||||
TypePtr AbstractBase::type() const {
|
||||
auto t = ToRef<AbstractBaseImpl>(impl_).BuildType();
|
||||
return ToWrapper<Type>(t);
|
||||
}
|
||||
|
||||
ValuePtr AbstractBase::value() const {
|
||||
auto v = ToRef<AbstractBaseImpl>(impl_).BuildValue();
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
void AbstractBase::set_type(const TypePtr &type) {
|
||||
auto type_impl = ToImpl<TypeImpl>(type);
|
||||
ToRef<AbstractBaseImpl>(impl_).set_type(type_impl);
|
||||
}
|
||||
|
||||
void AbstractBase::set_value(const ValuePtr &value) {
|
||||
auto value_impl = ToImpl<ValueImpl>(value);
|
||||
ToRef<AbstractBaseImpl>(impl_).set_value(value_impl);
|
||||
}
|
||||
|
||||
using AbstractTensorImpl = mindspore::abstract::AbstractTensor;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractTensor, AbstractTensorImpl, AbstractBase);
|
||||
|
||||
AbstractTensor::AbstractTensor(TypeId type, const ShapeVector &shape)
|
||||
: AbstractBase(std::make_shared<AbstractTensorImpl>(mindspore::TypeIdToType(type), shape)) {}
|
||||
|
||||
AbstractBasePtr AbstractTensor::element() const {
|
||||
auto abs = ToRef<AbstractTensorImpl>(impl_).element();
|
||||
return ToWrapper<AbstractBase>(abs);
|
||||
}
|
||||
|
||||
ShapePtr AbstractTensor::shape() const {
|
||||
auto s = ToRef<AbstractTensorImpl>(impl_).shape();
|
||||
return ToWrapper<Shape>(s);
|
||||
}
|
||||
|
||||
using AbstractSequenceImpl = mindspore::abstract::AbstractSequeue;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractSequence, AbstractSequenceImpl, AbstractBase);
|
||||
|
||||
AbstractBasePtrList AbstractSequence::elements() const {
|
||||
auto &impl_elements = ToRef<AbstractSequenceImpl>(impl_).elements();
|
||||
return ToWrapperVector<AbstractBase>(impl_elements);
|
||||
}
|
||||
|
||||
using AbstractTupleImpl = mindspore::abstract::AbstractTuple;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractTuple, AbstractTupleImpl, AbstractSequence);
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,135 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/anf.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using ValueImpl = mindspore::Value;
|
||||
using AnfNodeImpl = mindspore::AnfNode;
|
||||
using PrimitiveImpl = mindspore::Primitive;
|
||||
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
|
||||
|
||||
MIND_API_BASE_IMPL(AnfNode, AnfNodeImpl, Base);
|
||||
|
||||
std::string AnfNode::fullname_with_scope() const { return ToRef<AnfNodeImpl>(impl_).fullname_with_scope(); }
|
||||
|
||||
AbstractBasePtr AnfNode::abstract() const {
|
||||
const auto &abs = ToRef<AnfNodeImpl>(impl_).abstract();
|
||||
return ToWrapper<AbstractBase>(abs);
|
||||
}
|
||||
|
||||
void AnfNode::set_abstract(const AbstractBasePtr &abs) {
|
||||
ToRef<AnfNodeImpl>(impl_).set_abstract(ToImpl<AbstractBaseImpl>(abs));
|
||||
}
|
||||
|
||||
using CNodeImpl = mindspore::CNode;
|
||||
|
||||
MIND_API_BASE_IMPL(CNode, CNodeImpl, AnfNode);
|
||||
|
||||
size_t CNode::size() const { return ToRef<CNodeImpl>(impl_).size(); }
|
||||
|
||||
AnfNodePtr CNode::input(size_t i) const {
|
||||
auto &input = ToRef<CNodeImpl>(impl_).input(i);
|
||||
return ToWrapper<AnfNode>(input);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> CNode::inputs() const {
|
||||
auto &impl_inputs = ToRef<CNodeImpl>(impl_).inputs();
|
||||
return ToWrapperVector<AnfNode>(impl_inputs);
|
||||
}
|
||||
|
||||
void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
|
||||
auto impl_inputs = ToImplVector<AnfNodeImpl>(inputs);
|
||||
ToRef<CNodeImpl>(impl_).set_inputs(impl_inputs);
|
||||
}
|
||||
|
||||
void CNode::add_input(const AnfNodePtr &input) {
|
||||
auto impl_input = ToImpl<AnfNodeImpl>(input);
|
||||
MS_EXCEPTION_IF_NULL(impl_input);
|
||||
ToRef<CNodeImpl>(impl_).add_input(impl_input);
|
||||
}
|
||||
|
||||
void CNode::set_fullname_with_scope(const std::string &full_name) {
|
||||
ToRef<CNodeImpl>(impl_).set_fullname_with_scope(full_name);
|
||||
}
|
||||
|
||||
void CNode::AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||
auto impl_attr = ToImpl<ValueImpl>(attr);
|
||||
MS_EXCEPTION_IF_NULL(impl_attr);
|
||||
ToRef<CNodeImpl>(impl_).AddAttr(name, impl_attr);
|
||||
}
|
||||
|
||||
void CNode::EraseAttr(const std::string &name) { ToRef<CNodeImpl>(impl_).EraseAttr(name); }
|
||||
|
||||
ValuePtr CNode::GetAttr(const std::string &name) const {
|
||||
auto v = ToRef<CNodeImpl>(impl_).GetAttr(name);
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
using ParameterImpl = mindspore::Parameter;
|
||||
|
||||
MIND_API_BASE_IMPL(Parameter, ParameterImpl, AnfNode);
|
||||
|
||||
std::string Parameter::name() const { return ToRef<ParameterImpl>(impl_).name(); }
|
||||
|
||||
void Parameter::set_name(const std::string &name) { ToRef<ParameterImpl>(impl_).set_name(name); }
|
||||
|
||||
bool Parameter::has_default() const { return ToRef<ParameterImpl>(impl_).has_default(); }
|
||||
|
||||
void Parameter::set_default_param(const ValuePtr ¶m) {
|
||||
auto v = ToImpl<ValueImpl>(param);
|
||||
ToRef<ParameterImpl>(impl_).set_default_param(v);
|
||||
}
|
||||
|
||||
ValuePtr Parameter::default_param() const {
|
||||
auto v = ToRef<ParameterImpl>(impl_).default_param();
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
using ValueNodeImpl = mindspore::ValueNode;
|
||||
|
||||
MIND_API_BASE_IMPL(ValueNode, ValueNodeImpl, AnfNode);
|
||||
|
||||
ValueNode::ValueNode(const ValuePtr &value) : AnfNode(std::make_shared<ValueNodeImpl>(ToImpl<ValueImpl>(value))) {}
|
||||
|
||||
ValuePtr ValueNode::value() const {
|
||||
auto v = ToRef<ValueNodeImpl>(impl_).value();
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim) {
|
||||
auto node_impl = ToImpl<AnfNodeImpl>(node);
|
||||
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
|
||||
return mindspore::IsPrimitiveCNode(node_impl, prim_impl);
|
||||
}
|
||||
|
||||
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim) {
|
||||
auto node_impl = ToImpl<AnfNodeImpl>(node);
|
||||
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
|
||||
return mindspore::IsPrimitive(node_impl, prim_impl);
|
||||
}
|
||||
|
||||
bool IsDataNode(const AnfNodePtr &node) {
|
||||
auto node_impl = ToImpl<AnfNodeImpl>(node);
|
||||
// We assume that node with monad abstract is not a data node.
|
||||
return !HasAbstractMonad(node_impl);
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/base/base.h"
|
||||
#include "base/base.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
Base::Base(const std::shared_ptr<mindspore::Base> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl_); }
|
||||
|
||||
uint32_t Base::ClassId() { return mindspore::Base::kTypeId; }
|
||||
|
||||
bool Base::IsFromClassId(uint32_t class_id) const { return impl_->IsFromTypeId(class_id); }
|
||||
|
||||
std::string Base::ToString() const { return impl_->ToString(); }
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,170 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "mindapi/ir/func_graph.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#define USE_DEPRECATED_API
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/graph_utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using ValueImpl = mindspore::Value;
|
||||
using AnfNodeImpl = mindspore::AnfNode;
|
||||
using CNodeImpl = mindspore::CNode;
|
||||
using PrimitiveImpl = mindspore::Primitive;
|
||||
using ParameterImpl = mindspore::Parameter;
|
||||
using FuncGraphImpl = mindspore::FuncGraph;
|
||||
using FuncGraphManagerImpl = mindspore::FuncGraphManager;
|
||||
|
||||
MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value);
|
||||
|
||||
std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
|
||||
auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs();
|
||||
return ToWrapperVector<AnfNode>(inputs);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FuncGraph::parameters() const {
|
||||
auto ¶ms = ToRef<FuncGraphImpl>(impl_).parameters();
|
||||
return ToWrapperVector<AnfNode>(params);
|
||||
}
|
||||
|
||||
void FuncGraph::add_parameter(const ParameterPtr &p) {
|
||||
auto param_impl = ToImpl<ParameterImpl>(p);
|
||||
ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl);
|
||||
}
|
||||
|
||||
ParameterPtr FuncGraph::add_parameter() {
|
||||
auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter();
|
||||
return ToWrapper<Parameter>(param_impl);
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraph::output() const {
|
||||
auto output = ToRef<FuncGraphImpl>(impl_).output();
|
||||
return ToWrapper<AnfNode>(output);
|
||||
}
|
||||
|
||||
CNodePtr FuncGraph::get_return() const {
|
||||
auto ret = ToRef<FuncGraphImpl>(impl_).get_return();
|
||||
return ToWrapper<CNode>(ret);
|
||||
}
|
||||
|
||||
void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
|
||||
auto output = ToImpl<AnfNodeImpl>(value);
|
||||
ToRef<FuncGraphImpl>(impl_).set_output(output);
|
||||
}
|
||||
|
||||
void FuncGraph::set_return(const CNodePtr &cnode) {
|
||||
auto cnode_impl = ToImpl<CNodeImpl>(cnode);
|
||||
ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl);
|
||||
}
|
||||
|
||||
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||
auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs);
|
||||
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl));
|
||||
return ToWrapper<CNode>(cnode_impl);
|
||||
}
|
||||
|
||||
CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
|
||||
auto prim_impl = ToImpl<PrimitiveImpl>(primitive);
|
||||
auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs);
|
||||
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl);
|
||||
return ToWrapper<CNode>(cnode_impl);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FuncGraph::nodes() const {
|
||||
auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes();
|
||||
return ToWrapperVector<AnfNode>(nodes);
|
||||
}
|
||||
|
||||
bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); }
|
||||
|
||||
ValuePtr FuncGraph::get_attr(const std::string &key) const {
|
||||
auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key);
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) {
|
||||
auto value_impl = ToImpl<ValueImpl>(value);
|
||||
ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl);
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr FuncGraph::manager() const {
|
||||
auto manager = ToRef<FuncGraphImpl>(impl_).manager();
|
||||
if (manager == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return MakeShared<FuncGraphManager>(manager);
|
||||
}
|
||||
|
||||
FuncGraphPtr FuncGraph::Create() {
|
||||
auto fg = std::make_shared<FuncGraphImpl>();
|
||||
return ToWrapper<FuncGraph>(fg);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) {
|
||||
auto node_impl = ToImpl<AnfNodeImpl>(node);
|
||||
if (node_impl == nullptr) {
|
||||
return {};
|
||||
}
|
||||
auto sorted = mindspore::TopoSort(node_impl);
|
||||
return ToWrapperVector<AnfNode>(sorted);
|
||||
}
|
||||
|
||||
// FuncGraphManager is not derived from Base, we implement it directly.
|
||||
FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
}
|
||||
|
||||
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||
return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node));
|
||||
}
|
||||
|
||||
void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
|
||||
return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value));
|
||||
}
|
||||
|
||||
void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
|
||||
return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value));
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const {
|
||||
auto &node_users = impl_->node_users();
|
||||
auto iter = node_users.find(ToImpl<AnfNodeImpl>(node));
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
auto &users_impl = iter->second;
|
||||
std::vector<std::pair<AnfNodePtr, int>> users;
|
||||
users.reserve(users_impl.size());
|
||||
std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users),
|
||||
[](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); });
|
||||
return users;
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) {
|
||||
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
|
||||
auto mgr_impl = mindspore::Manage(fg_impl, manage);
|
||||
if (mgr_impl == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return MakeShared<FuncGraphManager>(mgr_impl);
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <type_traits>
|
||||
#include "mindapi/base/base.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
template <typename T, typename U>
|
||||
T &ToRef(const std::shared_ptr<U> &ptr) {
|
||||
return static_cast<T &>(*ptr);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename = typename std::enable_if_t<std::is_base_of_v<mindspore::Base, T>>,
|
||||
typename = typename std::enable_if_t<std::is_base_of_v<Base, U>>>
|
||||
std::shared_ptr<T> ToImpl(const SharedPtr<U> &wrapper) {
|
||||
if (wrapper == nullptr || wrapper->impl() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::dynamic_pointer_cast<T>(wrapper->impl());
|
||||
}
|
||||
|
||||
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>>>
|
||||
SharedPtr<T> ToWrapper(const std::shared_ptr<mindspore::Base> &impl) {
|
||||
if (impl == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return MakeShared<T>(impl);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
std::vector<std::shared_ptr<T>> ToImplVector(const U &wrapper_vector) {
|
||||
std::vector<std::shared_ptr<T>> impl_vector;
|
||||
impl_vector.reserve(wrapper_vector.size());
|
||||
for (auto &wrapper : wrapper_vector) {
|
||||
impl_vector.emplace_back(ToImpl<T>(wrapper));
|
||||
}
|
||||
return impl_vector;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
std::vector<SharedPtr<T>> ToWrapperVector(const U &impl_vector) {
|
||||
std::vector<SharedPtr<T>> wrapper_vector;
|
||||
wrapper_vector.reserve(impl_vector.size());
|
||||
for (auto &impl : impl_vector) {
|
||||
wrapper_vector.emplace_back(ToWrapper<T>(impl));
|
||||
}
|
||||
return wrapper_vector;
|
||||
}
|
||||
|
||||
#define MIND_API_BASE_IMPL(current_class, impl_class, base_class) \
|
||||
current_class::current_class(const std::shared_ptr<mindspore::Base> &impl) : base_class(impl) { \
|
||||
if (!impl_->isa<impl_class>()) { \
|
||||
MS_LOG(EXCEPTION) << "Wrong impl " << impl_->type_name() << " for " << #current_class; \
|
||||
} \
|
||||
} \
|
||||
uint32_t current_class::ClassId() { return impl_class::kTypeId; }
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define MIND_LOG_NO_MS_LOG
|
||||
#include "mindapi/base/logging.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
static MsLogLevel ToMsLogLevel(LogLevel level) {
|
||||
switch (level) {
|
||||
case LogLevel::DEBUG:
|
||||
return MsLogLevel::DEBUG;
|
||||
case LogLevel::INFO:
|
||||
return MsLogLevel::INFO;
|
||||
case LogLevel::WARNING:
|
||||
return MsLogLevel::WARNING;
|
||||
case LogLevel::ERROR:
|
||||
return MsLogLevel::ERROR;
|
||||
case LogLevel::EXCEPTION:
|
||||
return MsLogLevel::EXCEPTION;
|
||||
default:
|
||||
return MsLogLevel::EXCEPTION;
|
||||
}
|
||||
}
|
||||
|
||||
class LogWriterImpl {
|
||||
public:
|
||||
LogWriterImpl(LogLevel level, const char *file, int line, const char *func)
|
||||
: writer_(LocationInfo(file, line, func), ToMsLogLevel(level), SubModuleId::SM_API) {}
|
||||
|
||||
~LogWriterImpl() = default;
|
||||
|
||||
void Write(const LogStream &stream) const noexcept {
|
||||
mindspore::LogStream log_stream;
|
||||
log_stream << stream.stream_.rdbuf();
|
||||
writer_ < log_stream;
|
||||
}
|
||||
|
||||
void WriteAndThrow(const LogStream &stream) const __attribute__((noreturn)) {
|
||||
mindspore::LogStream log_stream;
|
||||
log_stream << stream.stream_.rdbuf();
|
||||
writer_ ^ log_stream;
|
||||
}
|
||||
|
||||
private:
|
||||
mindspore::LogWriter writer_;
|
||||
};
|
||||
|
||||
LogWriter::LogWriter(LogLevel level, const char *file, int line, const char *func)
|
||||
: impl_(std::make_unique<LogWriterImpl>(level, file, line, func)) {}
|
||||
|
||||
LogWriter::~LogWriter() = default;
|
||||
|
||||
void LogWriter::operator<(const LogStream &stream) const noexcept { impl_->Write(stream); }
|
||||
|
||||
void LogWriter::operator^(const LogStream &stream) const { impl_->WriteAndThrow(stream); }
|
||||
|
||||
bool LogWriter::IsEnabled(LogLevel level) {
|
||||
auto log_level = ToMsLogLevel(level);
|
||||
return IS_OUTPUT_ON(log_level);
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/primitive.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using ValueImpl = mindspore::Value;
|
||||
using PrimitiveImpl = mindspore::Primitive;
|
||||
|
||||
MIND_API_BASE_IMPL(Primitive, PrimitiveImpl, Value);
|
||||
|
||||
Primitive::Primitive(const std::string &name) : Value(std::make_shared<PrimitiveImpl>(name)) {}
|
||||
|
||||
const std::string &Primitive::name() const { return ToRef<PrimitiveImpl>(impl_).name(); }
|
||||
|
||||
Primitive &Primitive::AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||
auto value = ToImpl<ValueImpl>(attr);
|
||||
ToRef<PrimitiveImpl>(impl_).set_attr(name, value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Primitive &Primitive::SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
for (auto &attr : attrs) {
|
||||
auto value = ToImpl<ValueImpl>(attr.second);
|
||||
ToRef<PrimitiveImpl>(impl_).set_attr(attr.first, value);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Primitive::EraseAttr(const std::string &name) { ToRef<PrimitiveImpl>(impl_).EraseAttr(name); }
|
||||
|
||||
ValuePtr Primitive::GetAttr(const std::string &name) const {
|
||||
auto v = ToRef<PrimitiveImpl>(impl_).GetAttr(name);
|
||||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
bool Primitive::HasAttr(const std::string &name) const { return ToRef<PrimitiveImpl>(impl_).HasAttr(name); }
|
||||
|
||||
std::unordered_map<std::string, ValuePtr> Primitive::attrs() const {
|
||||
std::unordered_map<std::string, ValuePtr> attr_map;
|
||||
auto &impl_attrs = ToRef<PrimitiveImpl>(impl_).attrs();
|
||||
attr_map.reserve(impl_attrs.size());
|
||||
for (auto &attr : impl_attrs) {
|
||||
auto value = ToWrapper<Value>(attr.second);
|
||||
attr_map.emplace(attr.first, value);
|
||||
}
|
||||
return attr_map;
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/shape.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/dshape.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using ShapeImpl = mindspore::abstract::Shape;
|
||||
|
||||
MIND_API_BASE_IMPL(Shape, ShapeImpl, Base);
|
||||
|
||||
const ShapeVector &Shape::shape() const { return ToRef<ShapeImpl>(impl_).shape(); }
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "mindapi/ir/tensor.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using TensorImpl = mindspore::tensor::Tensor;
|
||||
|
||||
MIND_API_BASE_IMPL(Tensor, TensorImpl, Value);
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : Value(std::make_shared<TensorImpl>(data_type, shape)) {}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
|
||||
: Value(std::make_shared<TensorImpl>(data_type, shape, data, data_len)) {}
|
||||
|
||||
const ShapeVector &Tensor::shape() const { return ToRef<TensorImpl>(impl_).shape(); }
|
||||
|
||||
void Tensor::set_shape(const ShapeVector &shape) { (void)ToRef<TensorImpl>(impl_).set_shape(shape); }
|
||||
|
||||
TypeId Tensor::data_type() const { return ToRef<TensorImpl>(impl_).data_type(); }
|
||||
|
||||
void Tensor::set_data_type(const TypeId data_type) { (void)ToRef<TensorImpl>(impl_).set_data_type(data_type); }
|
||||
|
||||
const void *Tensor::data() const { return ToRef<TensorImpl>(impl_).data_c(); }
|
||||
|
||||
void *Tensor::data() { return ToRef<TensorImpl>(impl_).data_c(); }
|
||||
|
||||
int Tensor::DataSize() const { return ToRef<TensorImpl>(impl_).DataSize(); }
|
||||
|
||||
size_t Tensor::Size() const { return ToRef<TensorImpl>(impl_).Size(); }
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/type.h"
|
||||
#include "mindapi/ir/value.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "abstract/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using TypeImpl = mindspore::Type;
|
||||
|
||||
MIND_API_BASE_IMPL(Type, TypeImpl, Value);
|
||||
|
||||
TypeId Type::type_id() const { return ToRef<TypeImpl>(impl_).type_id(); }
|
||||
|
||||
TypeId Type::number_type() const { return ToRef<TypeImpl>(impl_).number_type(); }
|
||||
|
||||
TypePtr Type::GetType(TypeId id) {
|
||||
auto type_impl = mindspore::TypeIdToType(id);
|
||||
return ToWrapper<Type>(type_impl);
|
||||
}
|
||||
|
||||
size_t Type::GetSize(TypeId id) { return mindspore::abstract::TypeIdSize(id); }
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#define USE_DEPRECATED_API
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore::api::utils {
|
||||
using ValueImpl = mindspore::Value;
|
||||
using FuncGraphImpl = mindspore::FuncGraph;
|
||||
|
||||
MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
|
||||
Cloner cloner({fg_impl}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
|
||||
auto cloned_fg = cloner[fg_impl];
|
||||
return ToWrapper<api::FuncGraph>(cloned_fg);
|
||||
}
|
||||
|
||||
int64_t GetPadMode(const api::ValuePtr &value, bool is_upper) {
|
||||
int64_t result;
|
||||
auto value_impl = ToImpl<ValueImpl>(value);
|
||||
CheckAndConvertUtils::GetPadModEnumValue(value_impl, &result, is_upper);
|
||||
return result;
|
||||
}
|
||||
} // namespace mindspore::api::utils
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindapi/ir/value.h"
|
||||
#include "mindapi/ir/type.h"
|
||||
#include "mindapi/ir/abstract.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/scalar.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using ValueImpl = mindspore::Value;
|
||||
using ValueSequenceImpl = mindspore::ValueSequeue; // 'Sequeue' is typo.
|
||||
using ValueTupleImpl = mindspore::ValueTuple;
|
||||
using StringImmImpl = mindspore::StringImm;
|
||||
using ScalarImpl = mindspore::Scalar;
|
||||
using BoolImmImpl = mindspore::BoolImm;
|
||||
using IntegerImmImpl = mindspore::IntergerImm; // 'Interger' is typo.
|
||||
using Int64ImmImpl = mindspore::Int64Imm;
|
||||
using FloatImmImpl = mindspore::FloatImm;
|
||||
using FP32ImmImpl = mindspore::FP32Imm;
|
||||
|
||||
MIND_API_BASE_IMPL(Value, ValueImpl, Base);
|
||||
|
||||
TypePtr Value::type() const {
|
||||
auto t = ToRef<ValueImpl>(impl_).type();
|
||||
return ToWrapper<Type>(t);
|
||||
}
|
||||
|
||||
AbstractBasePtr Value::ToAbstract() const {
|
||||
auto abs = ToRef<ValueImpl>(impl_).ToAbstract();
|
||||
return ToWrapper<AbstractBase>(abs);
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(ValueSequence, ValueSequenceImpl, Value);
|
||||
|
||||
std::size_t ValueSequence::size() const { return ToRef<ValueSequenceImpl>(impl_).size(); }
|
||||
|
||||
std::vector<ValuePtr> ValueSequence::value() const {
|
||||
auto &elements = ToRef<ValueSequenceImpl>(impl_).value();
|
||||
return ToWrapperVector<Value>(elements);
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(ValueTuple, ValueTupleImpl, ValueSequence);
|
||||
|
||||
ValueTuple::ValueTuple(const std::vector<ValuePtr> &elements)
|
||||
: ValueSequence(std::make_shared<ValueTupleImpl>(ToImplVector<ValueImpl>(elements))) {}
|
||||
|
||||
MIND_API_BASE_IMPL(StringImm, StringImmImpl, Value);
|
||||
|
||||
StringImm::StringImm(const std::string &str) : Value(std::make_shared<StringImmImpl>(str)) {}
|
||||
|
||||
const std::string &StringImm::value() const { return ToRef<StringImmImpl>(impl_).value(); }
|
||||
|
||||
MIND_API_BASE_IMPL(Scalar, ScalarImpl, Value);
|
||||
|
||||
MIND_API_BASE_IMPL(BoolImm, BoolImmImpl, Scalar);
|
||||
|
||||
BoolImm::BoolImm(bool b) : Scalar(std::make_shared<BoolImmImpl>(b)) {}
|
||||
|
||||
bool BoolImm::value() const { return ToRef<BoolImmImpl>(impl_).value(); }
|
||||
|
||||
MIND_API_BASE_IMPL(IntegerImm, IntegerImmImpl, Scalar);
|
||||
|
||||
MIND_API_BASE_IMPL(Int64Imm, Int64ImmImpl, IntegerImm);
|
||||
|
||||
Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(value)) {}
|
||||
|
||||
int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); }
|
||||
|
||||
MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar);
|
||||
|
||||
MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm);
|
||||
|
||||
FP32Imm::FP32Imm(float value) : FloatImm(std::make_shared<FP32ImmImpl>(value)) {}
|
||||
|
||||
float FP32Imm::value() const { return ToRef<FP32ImmImpl>(impl_).value(); }
|
||||
} // namespace mindspore::api
|
|
@ -141,6 +141,7 @@ enum SubModuleId : int {
|
|||
SM_HCCL_ADPT, // Hccl Adapter
|
||||
SM_RUNTIME_FRAMEWORK, // Runtime framework
|
||||
SM_GE, // GraphEngine
|
||||
SM_API, // MindAPI
|
||||
NUM_SUBMODUES // number of submodules
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,7 +17,6 @@
|
|||
#ifndef MINDSPORE_SHAPE_UTILS_INFO_H_
|
||||
#define MINDSPORE_SHAPE_UTILS_INFO_H_
|
||||
|
||||
#include <vector>
|
||||
using ShapeVector = std::vector<int64_t>;
|
||||
#include "mindapi/base/shape_vector.h"
|
||||
|
||||
#endif // MINDSPORE_SHAPE_UTILS_INFO_H_
|
||||
|
|
|
@ -16,6 +16,8 @@ set(API_IR_HEADER
|
|||
${CORE_DIR}/api/ir/func_graph.h
|
||||
${CORE_DIR}/api/ir/func_graph_manager.h
|
||||
)
|
||||
file(GLOB MINDAPI_BASE_HEADER ${CORE_DIR}/mindapi/base/*.h)
|
||||
file(GLOB MINDAPI_IR_HEADER ${CORE_DIR}/mindapi/ir/*.h)
|
||||
set(BASE_HEADER
|
||||
${CORE_DIR}/base/base.h
|
||||
${CORE_DIR}/base/base_ref.h
|
||||
|
|
|
@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
|
|||
CFLAGS := -Ofast -std=c++17 \
|
||||
-I . \
|
||||
-I ./msl/runtime \
|
||||
-I ./msl/runtime/include \
|
||||
-I ./msl/runtime/minddata \
|
||||
-I ./msl/tools/third_party/flatbuffers/include
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
|
|||
CFLAGS := -Ofast -std=c++17 \
|
||||
-I . \
|
||||
-I ./msl/runtime \
|
||||
-I ./msl/runtime/include \
|
||||
-I ./msl/runtime/minddata \
|
||||
-I ./msl/tools/third_party/flatbuffers/include
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ INF_OBJ:=$(INF_SRC:.cc=.o)
|
|||
CFLAGS := -Ofast -std=c++17 \
|
||||
-I . \
|
||||
-I ./msl/runtime \
|
||||
-I ./msl/runtime/include \
|
||||
-I ./msl/runtime/minddata \
|
||||
-I ./msl/tools/third_party/flatbuffers/include
|
||||
|
||||
|
|
|
@ -289,6 +289,9 @@ if(APPLE)
|
|||
set(MINDSPORE_LITE_PUB_HDRS_IR_HDRS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_id.h
|
||||
)
|
||||
set(MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/mindapi/base/type_id.h
|
||||
)
|
||||
add_library(mindspore-lite_static STATIC
|
||||
${LITE_SRC}
|
||||
${MINDSPORE_LITE_PUB_HDRS}
|
||||
|
@ -423,6 +426,9 @@ if(DEFINED ARCHS)
|
|||
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_IR_HDRS})
|
||||
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/ir/dtype/)
|
||||
ENDFOREACH()
|
||||
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS})
|
||||
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/mindapi/base/)
|
||||
ENDFOREACH()
|
||||
target_link_libraries(mindspore-lite_static)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -71,6 +71,7 @@ if(ENABLE_MINDDATA)
|
|||
./fl/*.cc
|
||||
./cxx_api/*.cc
|
||||
./tbe/*.cc
|
||||
./mindapi/*.cc
|
||||
)
|
||||
if(NOT ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
|
|
|
@ -0,0 +1,394 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include "common/common_test.h"
|
||||
#include "mindapi/base/logging.h"
|
||||
#include "mindapi/ir/func_graph.h"
|
||||
#include "mindapi/ir/tensor.h"
|
||||
#include "mindapi/ir/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class TestMindApi : public UT::Common {
|
||||
public:
|
||||
TestMindApi() = default;
|
||||
};
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test basic 'is()' 'cast()'
|
||||
/// Expectation: is/cast works correctly.
|
||||
TEST_F(TestMindApi, test_base_isa_cast) {
|
||||
auto value_node = MakeShared<ValueNode>(MakeValue(0));
|
||||
auto base = MakeShared<Base>(value_node->impl());
|
||||
ASSERT_TRUE(base->isa<Base>());
|
||||
ASSERT_TRUE(base->isa<AnfNode>());
|
||||
ASSERT_TRUE(base->isa<ValueNode>());
|
||||
ASSERT_FALSE(base->isa<AbstractBase>());
|
||||
auto anf_node = base->cast<AnfNodePtr>();
|
||||
ASSERT_TRUE(anf_node != nullptr);
|
||||
ASSERT_TRUE(anf_node->impl() == value_node->impl());
|
||||
ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test graph construction.
|
||||
/// Expectation: graph is constructed as expected.
|
||||
TEST_F(TestMindApi, test_graph_construction) {
|
||||
// fg(x) { return myprim(x, 1); }
|
||||
auto fg = FuncGraph::Create();
|
||||
auto x = fg->add_parameter();
|
||||
x->set_name("x");
|
||||
auto prim = MakeShared<Primitive>("myprim");
|
||||
auto prim_node = MakeShared<ValueNode>(prim);
|
||||
auto value_node = MakeShared<ValueNode>(MakeValue(1));
|
||||
auto cnode = fg->NewCNode({prim_node, x, value_node});
|
||||
fg->set_output(cnode);
|
||||
|
||||
// Now we check the graph.
|
||||
ASSERT_EQ(fg->parameters().size(), 1);
|
||||
ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
|
||||
ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");
|
||||
|
||||
auto ret_node = fg->get_return();
|
||||
ASSERT_TRUE(ret_node != nullptr);
|
||||
auto output_node = fg->output();
|
||||
ASSERT_TRUE(output_node != nullptr);
|
||||
ASSERT_TRUE(output_node->isa<CNode>());
|
||||
|
||||
auto output_cnode = output_node->cast<CNodePtr>();
|
||||
ASSERT_EQ(output_cnode->inputs().size(), 3);
|
||||
ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
|
||||
ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
|
||||
ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
|
||||
ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
|
||||
ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
|
||||
ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());
|
||||
|
||||
ASSERT_EQ(output_cnode->impl(), cnode->impl());
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test value related functions.
|
||||
/// Expectation: value related functions work as expected.
|
||||
TEST_F(TestMindApi, test_values) {
|
||||
int64_t one = 1;
|
||||
auto s = MakeValue("hello");
|
||||
auto i = MakeValue(one);
|
||||
auto i2 = MakeValue(2);
|
||||
auto b = MakeValue(true);
|
||||
auto f = MakeValue(3.14f);
|
||||
auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
|
||||
auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));
|
||||
|
||||
ASSERT_TRUE(s->isa<StringImm>());
|
||||
ASSERT_TRUE(i->isa<Int64Imm>());
|
||||
ASSERT_TRUE(i2->isa<Int64Imm>());
|
||||
ASSERT_TRUE(b->isa<BoolImm>());
|
||||
ASSERT_TRUE(f->isa<FP32Imm>());
|
||||
ASSERT_TRUE(seq->isa<ValueSequence>());
|
||||
ASSERT_TRUE(seq_str->isa<ValueSequence>());
|
||||
|
||||
ASSERT_EQ(GetValue<std::string>(s), "hello");
|
||||
ASSERT_EQ(GetValue<int64_t>(i), one);
|
||||
ASSERT_EQ(GetValue<int64_t>(i2), 2);
|
||||
ASSERT_TRUE(GetValue<bool>(b));
|
||||
ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);
|
||||
|
||||
ASSERT_EQ(GetValue<std::string>(i), "");
|
||||
ASSERT_EQ(GetValue<int64_t>(s), 0);
|
||||
ASSERT_FALSE(GetValue<bool>(s));
|
||||
ASSERT_EQ(GetValue<float>(s), 0.0f);
|
||||
|
||||
auto seq_ptr = seq->cast<ValueSequencePtr>();
|
||||
ASSERT_TRUE(seq_ptr != nullptr);
|
||||
ASSERT_EQ(seq_ptr->size(), 3);
|
||||
ASSERT_EQ(seq_ptr->value().size(), 3);
|
||||
ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
|
||||
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
|
||||
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
|
||||
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);
|
||||
|
||||
auto seq_values = GetValue<std::vector<int64_t>>(seq);
|
||||
ASSERT_EQ(seq_values.size(), 3);
|
||||
ASSERT_EQ(seq_values[0], 3);
|
||||
ASSERT_EQ(seq_values[1], 4);
|
||||
ASSERT_EQ(seq_values[2], 5);
|
||||
|
||||
auto str_values = GetValue<std::vector<std::string>>(seq_str);
|
||||
ASSERT_EQ(str_values.size(), 4);
|
||||
ASSERT_EQ(str_values[0], "this");
|
||||
ASSERT_EQ(str_values[1], "is");
|
||||
ASSERT_EQ(str_values[2], "mindspore");
|
||||
ASSERT_EQ(str_values[3], "api");
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test graph manager functions.
|
||||
/// Expectation: graph manager functions work as expected.
|
||||
TEST_F(TestMindApi, test_func_graph_manager) {
|
||||
// fg(x, y) { return myprim(add(x, y), 1); }
|
||||
auto fg = FuncGraph::Create();
|
||||
auto x = fg->add_parameter();
|
||||
x->set_name("x");
|
||||
auto y = fg->add_parameter();
|
||||
y->set_name("y");
|
||||
auto add = MakeShared<Primitive>("add");
|
||||
auto add_node = MakeShared<ValueNode>(add);
|
||||
auto add_cnode = fg->NewCNode({add_node, x, y});
|
||||
auto prim = MakeShared<Primitive>("myprim");
|
||||
auto prim_node = MakeShared<ValueNode>(prim);
|
||||
auto value_node = MakeShared<ValueNode>(MakeValue(1));
|
||||
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
|
||||
fg->set_output(cnode);
|
||||
|
||||
auto mgr = FuncGraphManager::Manage(fg);
|
||||
ASSERT_TRUE(mgr != nullptr);
|
||||
ASSERT_TRUE(fg->manager() != nullptr);
|
||||
ASSERT_EQ(fg->manager()->impl(), mgr->impl());
|
||||
ASSERT_EQ(fg->manager(), mgr);
|
||||
|
||||
ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
|
||||
mgr->Replace(add_cnode, x);
|
||||
ASSERT_EQ(cnode->input(1)->impl(), x->impl());
|
||||
|
||||
mgr->SetEdge(cnode, 1, y);
|
||||
ASSERT_EQ(cnode->input(1)->impl(), y->impl());
|
||||
|
||||
mgr->AddEdge(cnode, x);
|
||||
ASSERT_EQ(cnode->size(), 4);
|
||||
ASSERT_EQ(cnode->input(3)->impl(), x->impl());
|
||||
|
||||
auto users = mgr->GetUsers(value_node);
|
||||
ASSERT_EQ(users.size(), 1);
|
||||
ASSERT_EQ(users[0].first, cnode);
|
||||
ASSERT_EQ(users[0].second, 2);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test value node utils.
|
||||
/// Expectation: value node utils work as expected.
|
||||
TEST_F(TestMindApi, test_value_node_utils) {
|
||||
auto fg = FuncGraph::Create();
|
||||
auto fg_node = MakeShared<ValueNode>(fg);
|
||||
auto prim = MakeShared<Primitive>("myprim");
|
||||
auto prim_node = MakeShared<ValueNode>(prim);
|
||||
auto one = MakeShared<ValueNode>(MakeValue(1));
|
||||
auto cnode = fg->NewCNode({fg_node, prim_node, one});
|
||||
|
||||
ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);
|
||||
|
||||
auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||
ASSERT_TRUE(fg1 != nullptr);
|
||||
ASSERT_TRUE(fg1->isa<FuncGraph>());
|
||||
|
||||
auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
|
||||
ASSERT_TRUE(prim1 != nullptr);
|
||||
ASSERT_TRUE(prim1->isa<Primitive>());
|
||||
|
||||
auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
|
||||
ASSERT_TRUE(imm != nullptr);
|
||||
ASSERT_TRUE(imm->isa<Int64Imm>());
|
||||
ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);
|
||||
|
||||
auto value = GetValueNode(cnode->input(2));
|
||||
ASSERT_TRUE(value != nullptr);
|
||||
ASSERT_EQ(GetValue<int64_t>(value), 1);
|
||||
|
||||
ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
|
||||
ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
|
||||
ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);
|
||||
|
||||
// Test NewValueNode.
|
||||
auto int_node = NewValueNode(1);
|
||||
auto bool_node = NewValueNode(true);
|
||||
auto float_node = NewValueNode(1.23f);
|
||||
auto str_node = NewValueNode("hello");
|
||||
|
||||
ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
|
||||
ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
|
||||
ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
|
||||
ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
|
||||
ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
|
||||
ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
|
||||
ASSERT_TRUE(str_node->value()->isa<StringImm>());
|
||||
ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test SharedPtr.
|
||||
/// Expectation: SharedPtr work as expected.
|
||||
TEST_F(TestMindApi, test_object_ptr) {
|
||||
auto fg = FuncGraph::Create();
|
||||
auto fg_node = MakeShared<ValueNode>(fg);
|
||||
auto prim = MakeShared<Primitive>("myprim");
|
||||
auto prim_node = MakeShared<ValueNode>(prim);
|
||||
auto one = MakeShared<ValueNode>(MakeValue(1));
|
||||
auto cnode = fg->NewCNode({fg_node, prim_node, one});
|
||||
|
||||
ASSERT_TRUE(fg != nullptr);
|
||||
ASSERT_FALSE(!fg);
|
||||
ASSERT_TRUE(fg ? true : false);
|
||||
ASSERT_TRUE((*cnode).input(0) == fg_node);
|
||||
ASSERT_TRUE(cnode->input(0) == fg_node);
|
||||
ASSERT_TRUE(cnode.get()->input(0) == fg_node);
|
||||
|
||||
ASSERT_EQ(cnode->input(0), fg_node);
|
||||
ASSERT_EQ(cnode->input(1), prim_node);
|
||||
ASSERT_EQ(cnode->input(2), one);
|
||||
ASSERT_TRUE(cnode->input(0) != fg);
|
||||
|
||||
AnfNodePtr p = fg_node;
|
||||
ASSERT_TRUE(p == fg_node);
|
||||
ASSERT_TRUE(p->isa<ValueNode>());
|
||||
ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
|
||||
ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);
|
||||
|
||||
p = cnode;
|
||||
ASSERT_TRUE(p == cnode);
|
||||
ASSERT_TRUE(p->isa<CNode>());
|
||||
ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
|
||||
ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
|
||||
ASSERT_TRUE(p.get() == cnode.get());
|
||||
|
||||
ASSERT_TRUE(p != nullptr);
|
||||
ASSERT_FALSE(p == nullptr);
|
||||
ASSERT_TRUE(p > nullptr);
|
||||
ASSERT_FALSE(p < nullptr);
|
||||
ASSERT_TRUE(p >= nullptr);
|
||||
ASSERT_FALSE(p <= nullptr);
|
||||
|
||||
ASSERT_TRUE(nullptr != p);
|
||||
ASSERT_FALSE(nullptr == p);
|
||||
ASSERT_TRUE(nullptr < p);
|
||||
ASSERT_FALSE(nullptr > p);
|
||||
ASSERT_TRUE(nullptr <= p);
|
||||
ASSERT_FALSE(nullptr >= p);
|
||||
|
||||
AnfNodePtr q = fg_node;
|
||||
ASSERT_TRUE(p != q);
|
||||
ASSERT_TRUE(p > q);
|
||||
if (p.get()->impl() > q.get()->impl()) {
|
||||
ASSERT_TRUE(p > q);
|
||||
ASSERT_TRUE(p >= q);
|
||||
ASSERT_TRUE(q < p);
|
||||
ASSERT_TRUE(q <= p);
|
||||
} else {
|
||||
ASSERT_TRUE(p < q);
|
||||
ASSERT_TRUE(p <= q);
|
||||
ASSERT_TRUE(q > p);
|
||||
ASSERT_TRUE(q >= p);
|
||||
}
|
||||
|
||||
std::stringstream ss1;
|
||||
std::stringstream ss2;
|
||||
ss1 << p;
|
||||
ss2 << cnode.get()->impl().get();
|
||||
ASSERT_EQ(ss1.str(), ss2.str());
|
||||
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
|
||||
mymap.emplace(p, q);
|
||||
mymap.emplace(q, p);
|
||||
ASSERT_TRUE(mymap.find(p) != mymap.end());
|
||||
ASSERT_TRUE(mymap.find(q) != mymap.end());
|
||||
ASSERT_TRUE(mymap[p] == q);
|
||||
ASSERT_TRUE(mymap[q] == p);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test Tensor API.
|
||||
/// Expectation: Tensor API work as expected.
|
||||
TEST_F(TestMindApi, test_tensor_api) {
|
||||
ShapeVector shape{1, 2, 3};
|
||||
auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
|
||||
|
||||
ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
|
||||
ASSERT_EQ(tensor->shape(), shape);
|
||||
ASSERT_EQ(tensor->DataSize(), 6);
|
||||
ASSERT_EQ(tensor->Size(), 24);
|
||||
|
||||
ShapeVector shape2{2, 3};
|
||||
tensor->set_data_type(kNumberTypeInt32);
|
||||
tensor->set_shape(shape2);
|
||||
ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
|
||||
ASSERT_EQ(tensor->shape(), shape2);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test utils API.
|
||||
/// Expectation: Tensor API work as expected.
|
||||
TEST_F(TestMindApi, test_api_utils) {
|
||||
// Test utils::isa, utils::cast.
|
||||
auto anf_node = NewValueNode("hello");
|
||||
ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
|
||||
ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
|
||||
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
|
||||
ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);
|
||||
|
||||
anf_node = nullptr;
|
||||
ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
|
||||
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);
|
||||
|
||||
// Test clone graph.
|
||||
auto fg = FuncGraph::Create();
|
||||
auto x = fg->add_parameter();
|
||||
x->set_name("x");
|
||||
auto y = fg->add_parameter();
|
||||
y->set_name("y");
|
||||
auto add = MakeShared<Primitive>("add");
|
||||
auto add_node = MakeShared<ValueNode>(add);
|
||||
auto add_cnode = fg->NewCNode({add_node, x, y});
|
||||
auto prim = MakeShared<Primitive>("myprim");
|
||||
auto prim_node = MakeShared<ValueNode>(prim);
|
||||
auto value_node = MakeShared<ValueNode>(MakeValue(1));
|
||||
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
|
||||
fg->set_output(cnode);
|
||||
|
||||
auto cloned_fg = utils::CloneGraph(fg);
|
||||
ASSERT_TRUE(cloned_fg != nullptr);
|
||||
ASSERT_EQ(cloned_fg->parameters().size(), 2);
|
||||
auto new_output = cloned_fg->output();
|
||||
ASSERT_TRUE(new_output != nullptr);
|
||||
ASSERT_TRUE(new_output->isa<CNode>());
|
||||
ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
|
||||
ASSERT_TRUE(new_output != cnode);
|
||||
ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);
|
||||
|
||||
// Test get pad mode.
|
||||
auto pm_lower = MakeValue("pad");
|
||||
auto pm_upper = MakeValue("PAD");
|
||||
ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
|
||||
ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
|
||||
ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test logging API.
|
||||
/// Expectation: logging work as expected.
|
||||
TEST_F(TestMindApi, test_api_logging) {
|
||||
MS_LOG(DEBUG) << "hello debug";
|
||||
MS_LOG(INFO) << "hello info";
|
||||
MS_LOG(WARNING) << "hello warning";
|
||||
MS_LOG(ERROR) << "hello error";
|
||||
try {
|
||||
MS_LOG(EXCEPTION) << "hello exception";
|
||||
ASSERT_TRUE(false);
|
||||
} catch (...) {
|
||||
}
|
||||
ASSERT_TRUE(true);
|
||||
}
|
||||
} // namespace mindspore::api
|
Loading…
Reference in New Issue