Provide new mindspore core API classes

1. namespace is mindspore::api;
2. API header files located in mindspore/core/mindapi;
3. We use pimpl pattern to provide a wrapper layer for api;
4. Check mindapi_test.cc for usage examples.
This commit is contained in:
He Wei 2021-12-03 09:54:03 +08:00
parent 25aa2bee49
commit d2bb6303b7
46 changed files with 3086 additions and 98 deletions

View File

@ -221,6 +221,8 @@ if(PLATFORM_ARM64)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@ -298,6 +300,8 @@ elseif(PLATFORM_ARM32)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@ -365,6 +369,8 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@ -409,6 +415,8 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@ -430,6 +438,10 @@ else()
PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE)
install(FILES ${API_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/api
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${MINDAPI_BASE_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${MINDAPI_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/ir
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${ABSTRACT_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${API_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api/ir

View File

@ -23,6 +23,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"ir/*.cc"
"utils/*.cc"
"load_mindir/*.cc"
"mindapi/src/*.cc"
)
if(ENABLE_SECURITY)

View File

@ -24,8 +24,7 @@
#include "utils/visible.h"
#include "api/ir/func_graph_manager.h"
namespace mindspore::api {
namespace mindspore::deprecated::api {
/// \brief FuncGraph defines interface for a function graph.
class MS_CORE_API FuncGraph {
public:
@ -147,5 +146,12 @@ class MS_CORE_API FuncGraph {
/// \return The function graph if the input is value node that holds the graph, nullptr otherwise.
static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input);
};
} // namespace mindspore::api
#ifndef USE_DEPRECATED_API
#define USE_DEPRECATED_API
namespace mindspore {
namespace api = deprecated::api;
}
#endif
} // namespace mindspore::deprecated::api
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_

View File

@ -26,8 +26,7 @@
#include "utils/hashing.h"
#include "ir/anf.h"
namespace mindspore::api {
namespace mindspore::deprecated::api {
class FuncGraph;
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
@ -80,7 +79,13 @@ class MS_CORE_API FuncGraphManager {
/// \return The manager that manages the given function graph.
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
};
} // namespace mindspore::deprecated::api
} // namespace mindspore::api
#ifndef USE_DEPRECATED_API
#define USE_DEPRECATED_API
namespace mindspore {
namespace api = deprecated::api;
}
#endif
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_MANAGER_H_

View File

@ -52,6 +52,7 @@ static const std::vector<std::string> sub_module_names = {
"HCCL_ADPT", // SM_HCCL_ADPT
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK
"GE", // SM_GE
"API", // SM_API
};
const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_names[(module_id % NUM_SUBMODUES)]; }

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,87 +19,6 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
#define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
namespace mindspore {
//
// Supported meta type
//
enum TypeId : int {
kTypeUnknown = 0,
kMetaTypeBegin = kTypeUnknown,
kMetaTypeType, // Type
kMetaTypeAnything,
kMetaTypeObject,
kMetaTypeTypeType, // TypeType
kMetaTypeProblem,
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
//
kObjectTypeBegin = kMetaTypeEnd,
kObjectTypeNumber,
kObjectTypeString,
kObjectTypeList,
kObjectTypeTuple,
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeRowTensorType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,
kObjectTypeJTagged,
kObjectTypeSymbolicKeyType,
kObjectTypeEnvType,
kObjectTypeRefKey,
kObjectTypeRef,
kObjectTypeEnd,
//
// Number Types
//
kNumberTypeBegin = kObjectTypeEnd,
kNumberTypeBool,
kNumberTypeInt,
kNumberTypeInt8,
kNumberTypeInt16,
kNumberTypeInt32,
kNumberTypeInt64,
kNumberTypeUInt,
kNumberTypeUInt8,
kNumberTypeUInt16,
kNumberTypeUInt32,
kNumberTypeUInt64,
kNumberTypeFloat,
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeComplex,
kNumberTypeComplex64,
kNumberTypeComplex128,
kNumberTypeInt4,
kNumberTypeGLUInt,
kNumberTypeEnd,
//
// Monad Types
//
kMonadTypeBegin = kNumberTypeEnd,
kObjectTypeMonad,
kObjectTypeUMonad,
kObjectTypeIOMonad,
kMonadTypeEnd,
//
// Sparse Types
//
// Sparse types is placed at the end of enum,
// in order to keep fit with the type of existing model on the lite side.
kSparseTypeBegin = kMonadTypeEnd,
kObjectTypeCSRTensorType,
kSparseTypeEnd
};
} // namespace mindspore
#include "mindapi/base/type_id.h"
#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_

View File

@ -153,7 +153,7 @@ class FuncGraphBase : public Value {
MS_DECLARE_PARENT(FuncGraphBase, Value);
};
class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
public:
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
@ -265,7 +265,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
FuncGraphManagerPtr manager() const { return manager_.lock(); }
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
deprecated::api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
std::string ToString() const override;
GraphDebugInfoPtr debug_info();

View File

@ -55,9 +55,9 @@ class FuncGraphTransaction;
class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
using AnfNodeIndexSet = api::AnfNodeIndexSet;
using AnfNodeIndexSet = deprecated::api::AnfNodeIndexSet;
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
using NodeUsersMap = api::NodeUsersMap;
using NodeUsersMap = deprecated::api::NodeUsersMap;
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
@ -277,7 +277,8 @@ class FuncGraphJTotalComputer final : public DepComputer {
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
};
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>,
public deprecated::api::FuncGraphManager {
public:
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager() {

View File

@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
#define MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
#include <cstdint>
#include <string>
#include <memory>
#include "mindapi/base/macros.h"
#include "mindapi/base/type_traits.h"
#include "mindapi/base/shared_ptr.h"
namespace mindspore {
class Base;
}
namespace mindspore::api {
/// \brief Base is the base class of many api classes, which provides basic interfaces.
class MIND_API Base {
public:
/// \brief Create an instance from the given implementation object.
///
/// \param[in] impl The shared_ptr to the implementation object.
explicit Base(const std::shared_ptr<mindspore::Base> &impl);
/// \brief Destructor of Base.
virtual ~Base() = default;
/// \brief Get the id of this class.
///
/// \return The id of this class.
static uint32_t ClassId();
/// \brief Get the shared_ptr to the underly implementation object.
///
/// \return The shared_ptr to the underly implementation object.
const std::shared_ptr<mindspore::Base> &impl() const { return impl_; }
/// \brief Get the string representation of this object.
///
/// \return The string representation.
std::string ToString() const;
/// \brief Check whether this object is an instance of the given class.
///
/// \return True if this object is an instance of the given class, false otherwise.
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
inline bool isa() const {
return IsFromClassId(T::ClassId());
}
/// \brief Cast this object to a pointer with the given pointer class.
///
/// \return A non-null pointer if cast success, nullptr otherwise.
template <typename T, typename U = typename std::enable_if_t<is_wrapper_ptr<T>::value, typename T::element_type>>
inline T cast() {
if (isa<U>()) {
return MakeShared<U>(impl_);
}
return nullptr;
}
protected:
bool IsFromClassId(uint32_t class_id) const;
const std::shared_ptr<mindspore::Base> impl_;
};
#define MIND_API_BASE_MEMBER(current_class) \
explicit current_class(const std::shared_ptr<mindspore::Base> &impl); \
~current_class() override = default; \
static uint32_t ClassId()
using BasePtr = SharedPtr<Base>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_BASE_BASE_H_

View File

@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
#define MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
#include <cstdint>
#include <memory>
#include <sstream>
#include <utility>
#include "mindapi/base/macros.h"
namespace mindspore::api {
enum class LogLevel : uint8_t { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION };
class LogWriterImpl;
/// \brief LogStream represents a stream to write log messages.
/// This class is not expected for directly use, use MS_LOG instead.
class LogStream {
public:
/// \brief Write log message to this LogStream.
///
/// \param[in] value The object to be written.
template <typename T>
LogStream &operator<<(T &&value) noexcept {
(void)stream_.operator<<(std::forward<T>(value));
return *this;
}
private:
friend class LogWriterImpl;
std::stringstream stream_;
};
/// \brief LogWriter defines interface for log message output.
/// This class is not expected for directly use, use MS_LOG instead.
class MIND_API LogWriter {
public:
/// \brief Create a LogWriter with the given log level, file name, line number and function name.
///
/// \param[in] level The log level.
/// \param[in] file The file name.
/// \param[in] line The line number.
/// \param[in] func The function name.
LogWriter(LogLevel level, const char *file, int line, const char *func);
/// \brief Destructor for LogWriter.
~LogWriter();
/// \brief Output log message from the input log stream.
///
/// \param[in] stream The input log stream.
void operator<(const LogStream &stream) const noexcept;
/// \brief Output log message from the input log stream and then throw exception.
///
/// \param[in] stream The input log stream.
void operator^(const LogStream &stream) const __attribute__((noreturn));
/// \brief Check whether the given log level is enabled or not.
///
/// \return True if the log level is enabled, false otherwise.
static bool IsEnabled(LogLevel level);
private:
std::unique_ptr<LogWriterImpl> impl_;
};
#define MIND_LOG_STREAM mindspore::api::LogStream()
#define MIND_LOG_WRITER mindspore::api::LogWriter
#define MIND_LOG_LEVEL(L) mindspore::api::LogLevel::L
#define MIND_LOG_THROW(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) ^ MIND_LOG_STREAM
#define MIND_LOG_WRITE(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) < MIND_LOG_STREAM
#define MIND_LOG_IF(L) \
if (MIND_LOG_WRITER::IsEnabled(MIND_LOG_LEVEL(L))) MIND_LOG_WRITE(L)
#define MIND_LOG_DEBUG MIND_LOG_IF(DEBUG)
#define MIND_LOG_INFO MIND_LOG_IF(INFO)
#define MIND_LOG_WARNING MIND_LOG_IF(WARNING)
#define MIND_LOG_ERROR MIND_LOG_IF(ERROR)
#define MIND_LOG_EXCEPTION MIND_LOG_THROW(EXCEPTION)
#define MIND_LOG(level) MIND_LOG_##level
#if !defined(MIND_LOG_NO_MS_LOG) && !defined(MS_LOG)
#define MS_LOG(level) MIND_LOG_##level
#endif
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_

View File

@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
#define MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
#if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__))
#ifdef BUILDING_DLL
#define MIND_API __declspec(dllexport)
#else
#define MIND_API __declspec(dllimport)
#endif
#else
#define MIND_API __attribute__((visibility("default")))
#endif
#endif // MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
#define MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
#include <cstdint>
#include <vector>
using ShapeVector = std::vector<int64_t>;
#endif // MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_

View File

@ -0,0 +1,180 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
#define MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
#include <cstdint>
#include <memory>
#include <utility>
#include <ostream>
#include <functional>
namespace mindspore::api {
/// \brief SharedPtr wraps a std::shared_ptr and provides wrapper functions according the underlying implementation.
template <typename T>
class SharedPtr {
public:
using element_type = T;
constexpr SharedPtr() noexcept = default;
constexpr SharedPtr(std::nullptr_t) noexcept : SharedPtr() {} // NOLINT
template <typename U>
explicit SharedPtr(std::shared_ptr<U> &&ptr) : ptr_(std::move(ptr)) {}
template <typename U>
SharedPtr(const SharedPtr<U> &other) : ptr_(other.ptr_) {}
template <typename U>
SharedPtr(SharedPtr<U> &&other) : ptr_(std::move(other.ptr_)) {}
template <typename U>
SharedPtr &operator=(const SharedPtr<U> &other) {
ptr_ = other.ptr_;
return *this;
}
template <typename U>
SharedPtr &operator=(SharedPtr<U> &&other) {
ptr_ = std::move(other.ptr_);
return *this;
}
~SharedPtr() = default;
std::uintptr_t addr() const { return (ptr_ == nullptr) ? 0 : reinterpret_cast<std::uintptr_t>(ptr_->impl().get()); }
element_type &operator*() const noexcept { return *ptr_; }
element_type *operator->() const noexcept { return ptr_.get(); }
element_type *get() const noexcept { return ptr_.get(); }
explicit operator bool() const { return addr() != 0; }
private:
template <typename U>
friend class SharedPtr;
std::shared_ptr<element_type> ptr_;
};
template <typename T, typename U>
inline bool operator==(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() == b.addr();
}
template <typename T>
inline bool operator==(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() == 0;
}
template <typename T>
inline bool operator==(std::nullptr_t, const SharedPtr<T> &a) noexcept {
return a.addr() == 0;
}
template <typename T, typename U>
inline bool operator!=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() != b.addr();
}
template <typename T>
inline bool operator!=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() != 0;
}
template <typename T>
inline bool operator!=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
return a.addr() != 0;
}
template <typename T, typename U>
inline bool operator<(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() < b.addr();
}
template <typename T>
inline bool operator<(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() < 0;
}
template <typename T>
inline bool operator<(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr < ptr' is false only when ptr is nullptr.
return a.addr() != 0;
}
template <typename T, typename U>
inline bool operator>(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() > b.addr();
}
template <typename T>
inline bool operator>(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() > 0;
}
template <typename T>
inline bool operator>(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr > ptr' is always false.
return false;
}
template <typename T, typename U>
inline bool operator<=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() <= b.addr();
}
template <typename T>
inline bool operator<=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() <= 0;
}
template <typename T>
inline bool operator<=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr <= ptr' is always true.
return true;
}
template <typename T, typename U>
inline bool operator>=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() >= b.addr();
}
template <typename T>
inline bool operator>=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() >= 0;
}
template <typename T>
inline bool operator>=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr >= ptr' is true only when ptr is nullptr.
return a.addr() == 0;
}
template <typename T, typename U, typename V>
inline std::basic_ostream<U, V> &operator<<(std::basic_ostream<U, V> &os, const SharedPtr<T> &a) {
return (os << reinterpret_cast<void *>(a.addr()));
}
/// \brief Constructs an object of type T and wraps it in a SharedPtr.
///
/// \param[in] args The parameter list for the constructor of T.
template <typename T, typename... Args>
inline SharedPtr<T> MakeShared(Args &&... args) {
auto ptr = std::make_shared<T>(std::forward<Args>(args)...);
return SharedPtr<T>(std::move(ptr));
}
} // namespace mindspore::api
namespace std {
template <typename T>
struct hash<mindspore::api::SharedPtr<T>> {
size_t operator()(const mindspore::api::SharedPtr<T> &ptr) const noexcept { return static_cast<size_t>(ptr.addr()); }
};
} // namespace std
#endif // MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_

View File

@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
namespace mindspore {
/// \brief TypeId defines data type identifiers.
enum TypeId : int {
kTypeUnknown = 0,
//
// Meta types.
//
kMetaTypeBegin = kTypeUnknown,
kMetaTypeType, // Type
kMetaTypeAnything,
kMetaTypeObject,
kMetaTypeTypeType, // TypeType
kMetaTypeProblem,
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
//
kObjectTypeBegin = kMetaTypeEnd,
kObjectTypeNumber,
kObjectTypeString,
kObjectTypeList,
kObjectTypeTuple,
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeRowTensorType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,
kObjectTypeJTagged,
kObjectTypeSymbolicKeyType,
kObjectTypeEnvType,
kObjectTypeRefKey,
kObjectTypeRef,
kObjectTypeEnd,
//
// Number Types
//
kNumberTypeBegin = kObjectTypeEnd,
kNumberTypeBool,
kNumberTypeInt,
kNumberTypeInt8,
kNumberTypeInt16,
kNumberTypeInt32,
kNumberTypeInt64,
kNumberTypeUInt,
kNumberTypeUInt8,
kNumberTypeUInt16,
kNumberTypeUInt32,
kNumberTypeUInt64,
kNumberTypeFloat,
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeComplex,
kNumberTypeComplex64,
kNumberTypeComplex128,
kNumberTypeInt4,
kNumberTypeGLUInt,
kNumberTypeEnd,
//
// Monad Types
//
kMonadTypeBegin = kNumberTypeEnd,
kObjectTypeMonad,
kObjectTypeUMonad,
kObjectTypeIOMonad,
kMonadTypeEnd,
//
// Sparse Types
//
// Sparse types is placed at the end of enum,
// in order to keep fit with the type of existing model on the lite side.
kSparseTypeBegin = kMonadTypeEnd,
kObjectTypeCSRTensorType,
kSparseTypeEnd
};
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
#include <vector>
#include <memory>
#include <type_traits>
#include "mindapi/base/shared_ptr.h"
namespace mindspore::api {
template <typename T>
struct is_wrapper_ptr : public std::false_type {};
template <typename T>
struct is_wrapper_ptr<SharedPtr<T>> : public std::true_type {};
template <typename T>
struct is_shared_ptr : public std::false_type {};
template <typename T>
struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
template <typename T>
struct is_vector : public std::false_type {};
template <typename T, typename A>
struct is_vector<std::vector<T, A>> : public std::true_type {};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_

View File

@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
#define MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/shape.h"
#include "mindapi/ir/type.h"
#include "mindapi/ir/value.h"
namespace mindspore::api {
/// \brief AbstractBase defines base interfaces for abstract of an anf node.
class MIND_API AbstractBase : public Base {
public:
MIND_API_BASE_MEMBER(AbstractBase);
/// \brief Clone an abstract from this abstract.
///
/// \return A pointer to the cloned abstract.
AbstractBasePtr Clone() const;
/// \brief Get the abstract type.
///
/// \return A pointer to the Type.
TypePtr type() const;
/// \brief Get the abstract value.
///
/// \return A pointer to the Value.
ValuePtr value() const;
/// \brief Set the type for this abstract.
///
/// \param[in] type The type to be set.
void set_type(const TypePtr &type);
/// \brief Set the value for this abstract.
///
/// \param[in] value The value to be set.
void set_value(const ValuePtr &value);
};
/// \brief AbstractTensor describes a tensor's type, shape and value.
class MIND_API AbstractTensor : public AbstractBase {
public:
MIND_API_BASE_MEMBER(AbstractTensor);
/// \brief Create AbstractTensor from the given type and shape.
///
/// \param[in] type The data type id of the tensor.
/// \param[in] shape The shape of the tensor.
AbstractTensor(TypeId type, const ShapeVector &shape);
/// \brief Get the element abstract.
///
/// \return A pointer to the element abstract.
AbstractBasePtr element() const;
/// \brief Get the shape of the abstract.
///
/// \return A pointer to the shape.
ShapePtr shape() const;
};
using AbstractTensorPtr = SharedPtr<AbstractTensor>;
/// \brief AbstractSequence describes the abstract for a tuple or list.
class MIND_API AbstractSequence : public AbstractBase {
public:
MIND_API_BASE_MEMBER(AbstractSequence);
/// \brief Get element abstracts.
///
/// \return A vector of element abstracts.
AbstractBasePtrList elements() const;
};
using AbstractSequencePtr = SharedPtr<AbstractSequence>;
/// \brief AbstractTuple describes the abstract for a tuple.
class MIND_API AbstractTuple : public AbstractSequence {
public:
MIND_API_BASE_MEMBER(AbstractTuple);
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_

View File

@ -0,0 +1,232 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
#define MINDSPORE_CORE_MINDAPI_IR_ANF_H_
#include <vector>
#include <string>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/abstract.h"
#include "mindapi/ir/primitive.h"
#include "mindapi/ir/value.h"
namespace mindspore::api {
/// \brief AnfNode is the basic class of the IR graph node.
class MIND_API AnfNode : public Base {
public:
MIND_API_BASE_MEMBER(AnfNode);
/// \brief Obtain detailed information about scope namespace.
///
/// \return Detailed information about scope namespace.
std::string fullname_with_scope() const;
/// \brief Obtain the inferred abstract value of this AnfNode.
///
/// \return The inferred abstract value.
AbstractBasePtr abstract() const;
/// \brief Set the abstract value of this AnfNode.
///
/// \param[in] abs New abstract value.
void set_abstract(const AbstractBasePtr &abs);
};
/// \brief CNode represents a compute node with a set of input nodes.
class MIND_API CNode : public AnfNode {
public:
MIND_API_BASE_MEMBER(CNode);
/// \brief Get the number of inputs.
///
/// \return The number of inputs in this CNode.
size_t size() const;
/// \brief Get the input node of the given index.
///
/// \param[in] i The given index.
///
/// \return The input node of the given index.
AnfNodePtr input(size_t i) const;
/// \brief Get the input nodes.
///
/// \return The input nodes of this CNode.
std::vector<AnfNodePtr> inputs() const;
/// \brief Set the input nodes for this CNode.
///
/// \param[in] inputs Input nodes.
void set_inputs(const std::vector<AnfNodePtr> &inputs);
/// \brief Add an input node to this CNode.
///
/// \param[in] input the input node to be added.
void add_input(const AnfNodePtr &input);
/// \brief Set fullname_with_scope for this CNode.
///
/// \param[in] full_name The fullname_with_scope.
void set_fullname_with_scope(const std::string &full_name);
/// \brief Add a new attribute to this CNode.
///
/// \param[in] name The name of the new attribute.
/// \param[in] attr The value of the new attribute.
void AddAttr(const std::string &name, const ValuePtr &attr);
/// \brief Erase the attribute with the given name.
///
/// \param[in] name The name of attribute.
void EraseAttr(const std::string &name);
/// \brief Get the attribute with the given name.
///
/// \param[in] name The name of attribute.
/// \return Attribute.
ValuePtr GetAttr(const std::string &name) const;
};
using CNodePtr = SharedPtr<CNode>;
/// \brief Parameter represents the parameter inputs of a function.
class MIND_API Parameter : public AnfNode {
public:
MIND_API_BASE_MEMBER(Parameter);
/// \brief Get the name of this Parameter.
///
/// \return The name.
std::string name() const;
/// \brief Set the name of this Parameter.
///
/// \param[in] The name.
void set_name(const std::string &name);
/// \brief Check if there is a default parameter.
///
/// \return True if this Parameter has a default parameter, otherwise false.
bool has_default() const;
/// \brief Set the default parameter.
///
/// \param[in] param The default parameter.
void set_default_param(const ValuePtr &param);
/// \brief Get the default parameter.
///
/// \return The default parameter.
ValuePtr default_param() const;
};
using ParameterPtr = SharedPtr<Parameter>;
/// \brief ValueNode is a graph node that hold a value.
class MIND_API ValueNode : public AnfNode {
public:
MIND_API_BASE_MEMBER(ValueNode);
/// \brief Create ValueNode with the given value.
///
/// \param[in] value The value of this ValueNode.
explicit ValueNode(const ValuePtr &value);
/// \brief Get the value of this ValueNode.
///
/// \return The value.
ValuePtr value() const;
};
using ValueNodePtr = SharedPtr<ValueNode>;
// === ANF utility functions === //
/// \brief Create a ValueNode with the given value.
///
/// \param[in] value The given value.
///
/// \return The created ValueNode.
inline ValueNodePtr NewValueNode(const ValuePtr &value) { return MakeShared<ValueNode>(value); }
/// \brief Create a ValueNode with the given primitive type value.
///
/// \param[in] value The given primitive type value.
///
/// \return The created ValueNode.
template <typename T>
inline ValueNodePtr NewValueNode(T value) {
return NewValueNode(MakeValue(value));
}
/// \brief Get the value from a node if it is a ValueNode.
///
/// \param[in] node The node which may hold a value.
///
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set.
inline ValuePtr GetValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
return value_node->value();
}
/// \brief Get the value with the given type from a node if it is a ValueNode.
///
/// \param[in] node The node which may hold a value.
///
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch.
template <typename T, typename = typename std::enable_if_t<
is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
inline T GetValueNode(const AnfNodePtr &node) {
auto value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
}
return value->cast<T>();
}
/// \brief Check whether the given node is a cnode with the given Primitive as the first input.
///
/// \param[in] node The given node to be checked.
/// \param[in] prim The Primitive value, nullptr means match any Primitive.
///
/// \return True if the node is cnode and the first input is the given Primitive, false otherwise.
MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);
/// \brief Check whether the given node is a ValueNode with the given Primitive.
///
/// \param[in] node The given node to be checked.
/// \param[in] prim The Primitive value.
///
/// \return True if the given node is a ValueNode with the given Primitive, false otherwise.
MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);
/// \brief Check if a node is a data node.
/// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes.
///
/// \param[in] node The node to be checked.
///
/// \return True if the node is a data node, false otherwise.
MIND_API bool IsDataNode(const AnfNodePtr &node);
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_ANF_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
#define MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
#include <vector>
#include "mindapi/base/shared_ptr.h"
namespace mindspore::api {
class AnfNode;
using AnfNodePtr = SharedPtr<AnfNode>;
using AnfNodePtrList = std::vector<AnfNodePtr>;
class Value;
using ValuePtr = SharedPtr<Value>;
class Primitive;
using PrimitivePtr = SharedPtr<Primitive>;
class Type;
using TypePtr = SharedPtr<Type>;
class AbstractBase;
using AbstractBasePtr = SharedPtr<AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
class Shape;
using ShapePtr = SharedPtr<Shape>;
class FuncGraph;
using FuncGraphPtr = SharedPtr<FuncGraph>;
class FuncGraphManager;
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_

View File

@ -0,0 +1,193 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
#define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/anf.h"
#include "mindapi/ir/primitive.h"
#include "mindapi/ir/value.h"
#include "mindapi/ir/utils.h"
namespace mindspore {
class FuncGraphManager;
}
namespace mindspore::api {
/// \brief FuncGraph defines interface for a function graph.
class MIND_API FuncGraph : public Value {
public:
MIND_API_BASE_MEMBER(FuncGraph);
/// \brief Get the input parameters.
///
/// \return Input parameters of this graph.
std::vector<AnfNodePtr> get_inputs() const;
/// \brief Get all parameters.
///
/// \return All parameters of this graph.
std::vector<AnfNodePtr> parameters() const;
/// \brief Adds a parameter to this graph.
///
/// \param[in] p The parameter to be added.
void add_parameter(const ParameterPtr &p);
/// \brief Adds a new parameter to this graph.
///
/// \return The new added parameter.
ParameterPtr add_parameter();
/// \brief Get the output node.
///
/// \return The output node, nullptr if output not set.
AnfNodePtr output() const;
/// \brief Get the return CNode.
///
/// \return The return CNode, nullptr if no return node.
CNodePtr get_return() const;
/// \brief Set the output node.
///
/// \param[in] value The output node to be set.
/// \param[in] force_new_ret If true, a new return node is always created.
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
/// \brief Set the return node.
///
/// \param[in] cnode The return CNode to be set.
void set_return(const CNodePtr &cnode);
/// \brief Creates a new CNode in this graph.
///
/// \param[in] inputs The input nodes of the new CNode.
///
/// \return The created CNode.
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
/// \brief Creates a new primitive CNode in this graph.
///
/// \param[in] primitive The primitive of the new CNode.
/// \param[in] prim_inputs The argument inputs of the primitive CNode.
///
/// \return The created primitive CNode.
CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
/// \brief Get all nodes in this graph.
///
/// \return All nodes in this graph.
std::vector<AnfNodePtr> nodes() const;
/// \brief Check whether an attribute is set for this graph.
///
/// \param[in] key The attribute key (name).
///
/// \return True if the attribute with the given key is set, false otherwise.
bool has_attr(const std::string &key) const;
/// \brief Get an attribute value by its key.
///
/// \param[in] key The attribute key (name).
///
/// \return The attribute value for the given key, nullptr if attribute not found.
ValuePtr get_attr(const std::string &key) const;
/// \brief Set an attribute value.
///
/// \param[in] key The attribute key (name).
/// \param[in] value The attribute value.
void set_attr(const std::string &key, const ValuePtr &value);
/// \brief Get the manager for this graph.
///
/// \return The manager of this graph, nullptr if not set.
FuncGraphManagerPtr manager() const;
/// \brief Creates an empty function graph.
///
/// \return The created function graph.
static FuncGraphPtr Create();
/// \brief Topological sort a graph from the given end node.
///
/// \param[in] node The end node of the graph to be sorted.
///
/// \return The sorted nodes.
static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
};
/// \brief FuncGraphManager defines interface for function graph management.
class MIND_API FuncGraphManager {
public:
/// \brief Create FuncGraphManager with the given implementor object.
///
/// \param[in] impl The pointer to the implementor object.
explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl);
/// \brief Get the shared_ptr to the underly implementation object.
///
/// \return The shared_ptr to the underly implementation object.
const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; }
/// \brief Replace an old node with a new node, related edges are all updated.
///
/// \param[in] old_node The old node to be replaced.
/// \param[in] new_node The new node that replace the old one.
///
/// \return True if the node is successfully replaced, false otherwise.
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
/// \brief Change an existed edge by replace its input node.
///
/// \param[in] node The output node of the edge.
/// \param[in] index The input index in output node.
/// \param[in] value The new input node of the edge.
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
/// \brief Adds a new edge between the given two nodes.
///
/// \param[in] node The output node of the edge.
/// \param[in] value The input node of the edge.
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
/// \brief Find users of the given node.
///
/// \param[in] node The node.
///
/// \return Users of the given node, empty if user not found.
std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const;
/// \brief Manage the give function graph.
///
/// \param[in] func_graph The function graph to be managed.
/// \param[in] manage If true, the created manager will be set in the graph.
///
/// \return The manager that manages the given function graph.
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
private:
const std::shared_ptr<mindspore::FuncGraphManager> impl_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_

View File

@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
#define MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
#include <vector>
#include <string>
#include <unordered_map>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"
namespace mindspore::api {
/// \brief Primitive defines a primitive operator.
class MIND_API Primitive : public Value {
public:
MIND_API_BASE_MEMBER(Primitive);
/// \brief Create primitive with the given name.
///
/// \param[in] name The primitive name.
explicit Primitive(const std::string &name);
/// \brief Get name of the primitive.
///
/// \return The name of primitive.
const std::string &name() const;
/// \brief Add attribute to primitive.
///
/// \param[in] name The attribute name.
/// \param[in] attr The attribute value.
/// \return The primitive to which attribute has been added.
Primitive &AddAttr(const std::string &name, const ValuePtr &attr);
/// \brief Add attributes by using a map, all elements of the map will be added to this primitive.
///
/// \param[in] attrs The attribute map needs to be added in the primitive attribute.
/// \return The primitive to which attribute has been added.
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs);
/// \brief Erase attribute to the primitive attribute map.
///
/// \param[in] name The attribute name.
void EraseAttr(const std::string &name);
/// \brief Get attribute value by name.
///
/// \param[in] name the attribute name.
/// \return The value of the attribute, null if attribute name not found.
ValuePtr GetAttr(const std::string &name) const;
/// \brief Check If Primitive has an attribute with then given name.
///
/// \param[in] name The attribute name.
/// \return True if there is an attribute with the given name, otherwise false.
bool HasAttr(const std::string &name) const;
/// \brief Get all attributes of this primitive as a map.
///
/// \return The attribute map of this primitive.
std::unordered_map<std::string, ValuePtr> attrs() const;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
#define MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
#include "mindapi/base/base.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/ir/common.h"
namespace mindspore::api {
/// \brief Shape defines dimensions of a tensor.
class MIND_API Shape : public Base {
public:
MIND_API_BASE_MEMBER(Shape);
/// \brief Get the shape dimensions.
///
/// \return The shape dimensions.
const ShapeVector &shape() const;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_

View File

@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
#define MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
#include <cstdint>
#include "mindapi/base/base.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/type_id.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"
namespace mindspore::api {
/// \brief Tensor represents a multi-dimensional array of elements.
class MIND_API Tensor : public Value {
public:
MIND_API_BASE_MEMBER(Tensor);
/// \brief Create a lazy allocated tensor.
///
/// \param[in] data_type [TypeId] Data type of the tensor.
/// \param[in] shape The shape represented by ShapeVector of the tensor.
Tensor(TypeId data_type, const ShapeVector &shape);
/// \brief Create a tensor with input data buffer.
///
/// \param[in] data_type [TypeId] Data type of the tensor.
/// \param[in] shape The shape represented by ShapeVector of the tensor.
/// \param[in] data The input data to be copied into tensor.
/// \param[in] data_len The length of data in bytes.
Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
/// \brief Get the shape of the tensor.
/// The shape of a tensor is stored in a vector<int64_t>. Each element of the
/// vector represents the size of a dimension of the tensor. The order of each
/// element in the vector is the same as the the dimension's order it represents.
///
/// \return A vector<int64_t> which represents the shape of the tensor.
const ShapeVector &shape() const;
/// \brief Set the shape of tensor.
///
/// \param[in] shape The shape to be set.
void set_shape(const ShapeVector &shape);
/// \brief Get the data type of the tensor.
///
/// \return The data type of the tensor.
TypeId data_type() const;
/// \brief Set the data type of the tensor.
///
/// \param[in] data_type The data type to be set.
void set_data_type(const TypeId data_type);
/// \brief Get The pointer to the underlying memory block for data storage.
///
/// \return The pointer to the underlying data.
const void *data() const;
/// \brief Get The pointer to the underlying memory block for data storage.
///
/// \return The pointer to the underlying data.
void *data();
/// \brief Get tensor data size.
///
/// \return The total number of elements in the tensor.
int DataSize() const;
/// \brief Get tensor data size in bytes.
///
/// \return The total number of bytes for the tensor data.
std::size_t Size() const;
};
using TensorPtr = SharedPtr<Tensor>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
#define MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
#include "mindapi/base/base.h"
#include "mindapi/base/type_id.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"
namespace mindspore::api {
/// \brief Type defines the type of a value.
class MIND_API Type : public Value {
public:
MIND_API_BASE_MEMBER(Type);
/// \brief Get the id of the Type object.
///
/// \return The id of the Type object.
TypeId type_id() const;
/// \brief Get the number type of the Type object.
///
/// \return The number type of this Type object, kTypeUnknown if this is not a number type.
TypeId number_type() const;
/// \brief Get the Type according to a TypeId.
///
/// \param[in] id The id of the type.
///
/// \return The pointer to the Type.
static TypePtr GetType(TypeId id);
/// \brief Get data size in bytes for the type according to a TypeId.
///
/// \param[in] id The id of the type.
///
/// \return The data size in bytes for the Type.
static size_t GetSize(TypeId id);
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_TYPE_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
#define MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
#include "mindapi/base/base.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/base/type_traits.h"
#include "mindapi/ir/anf.h"
#include "mindapi/ir/value.h"
#include "mindapi/ir/func_graph.h"
namespace mindspore::api::utils {
/// \brief Check whether the given object is an instance of the given class.
///
/// \param[in] ptr The pointer to the given object.
///
/// \return True if the pointer is not null and the object is an instance of the given class, false otherwise.
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
bool isa(const BasePtr &ptr) {
if (ptr == nullptr) {
return false;
}
return ptr->isa<T>();
}
/// \brief Cast the given object pointer to a pointer with the given class.
///
/// \param[in] ptr The pointer to the object to casted.
///
/// \return A non-null pointer if the input pointer is not null and cast success, nullptr otherwise.
template <typename T, typename = typename std::enable_if_t<is_wrapper_ptr<T>::value, T>>
T cast(const BasePtr &ptr) {
if (ptr == nullptr) {
return nullptr;
}
return ptr->cast<T>();
}
/// \brief Make a copy from the given function graph.
///
/// \param[in] func_graph The graph to be cloned.
///
/// \return The cloned graph.
MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph);
/// \brief Get pad mode id from a value holds the pad mode name or id.
///
/// \param[in] value The value holds the pad mode name or id.
/// \param[in] is_upper Indicates whether the name is uppercase or lowercase, default is false for lowercase.
///
/// \return The pad mode id.
MIND_API int64_t GetPadMode(const ValuePtr &value, bool is_upper = false);
} // namespace mindspore::api::utils
#endif // MINDSPORE_CORE_MINDAPI_IR_UTILS_H_

View File

@ -0,0 +1,270 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#include <vector>
#include <string>
#include <type_traits>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
namespace mindspore::api {
template <typename T>
struct ImmTrait {};
#define MIND_API_IMM_TRAIT(typeimm, prototype) \
template <> \
struct ImmTrait<prototype> { \
using type = SharedPtr<typeimm>; \
}
/// \brief Value represents a value in expression.
class MIND_API Value : public Base {
public:
MIND_API_BASE_MEMBER(Value);
/// \brief Get the type of this Value.
///
/// \return The type.
TypePtr type() const;
/// \brief Get the abstract of this Value.
///
/// \return Abstract of this Value.
AbstractBasePtr ToAbstract() const;
};
/// \brief ValueSequence represents a sequence of values.
class MIND_API ValueSequence : public Value {
public:
MIND_API_BASE_MEMBER(ValueSequence);
/// \brief Get the size of this ValueSequence.
///
/// \return The size as the number of elements.
std::size_t size() const;
/// \brief Get the list of values in this ValueSequence.
///
/// \return The list of element values.
std::vector<ValuePtr> value() const;
};
using ValueSequencePtr = SharedPtr<ValueSequence>;
/// \brief ValueTuple represents a value tuple.
class MIND_API ValueTuple : public ValueSequence {
public:
MIND_API_BASE_MEMBER(ValueTuple);
/// \brief Constructor of ValueTuple.
///
/// \param[in] elements The elements of the tuple.
explicit ValueTuple(const std::vector<ValuePtr> &elements);
};
using ValueTuplePtr = SharedPtr<ValueTuple>;
/// \brief StringImm defines a Value whose type is string.
class MIND_API StringImm : public Value {
public:
MIND_API_BASE_MEMBER(StringImm);
/// \brief Create StringImm with the given string.
///
/// \param[in] str The given string value.
explicit StringImm(const std::string &str);
/// \brief Get the string value of this StringImm.
///
/// \return The string value of this StringImm.
const std::string &value() const;
};
using StringImmPtr = SharedPtr<StringImm>;
MIND_API_IMM_TRAIT(StringImm, std::string);
/// \beief Scalar defines interface for scalar data.
class MIND_API Scalar : public Value {
public:
MIND_API_BASE_MEMBER(Scalar);
};
/// \beief BoolImm defines interface for bool data.
class MIND_API BoolImm : public Scalar {
public:
MIND_API_BASE_MEMBER(BoolImm);
/// \brief Create BoolImm with the given bool value.
///
/// \param[in] b The given bool value.
explicit BoolImm(bool b);
/// \brief Get the bool value of this BoolImm.
///
/// \return The bool value of this BoolImm.
bool value() const;
};
using BoolImmPtr = SharedPtr<BoolImm>;
MIND_API_IMM_TRAIT(BoolImm, bool);
/// \beief IntegerImm defines interface for integer data.
class MIND_API IntegerImm : public Scalar {
public:
MIND_API_BASE_MEMBER(IntegerImm);
};
/// \beief Int64Imm defines interface for int64 data.
class MIND_API Int64Imm : public IntegerImm {
public:
MIND_API_BASE_MEMBER(Int64Imm);
/// \brief Create Int64Imm with the given int64 value.
///
/// \param[in] value The given bool value.
explicit Int64Imm(int64_t value);
/// \brief Get the int64 value of this Int64Imm.
///
/// \return The int64 value of this Int64Imm.
int64_t value() const;
};
using Int64ImmPtr = SharedPtr<Int64Imm>;
MIND_API_IMM_TRAIT(Int64Imm, int64_t);
/// \beief FloatImm defines interface for float data.
class MIND_API FloatImm : public Scalar {
public:
MIND_API_BASE_MEMBER(FloatImm);
};
/// \beief FP32Imm defines interface for float32 data.
class MIND_API FP32Imm : public FloatImm {
public:
MIND_API_BASE_MEMBER(FP32Imm);
/// \brief Create FP32Imm with the given float value.
///
/// \param[in] value The given float value.
explicit FP32Imm(float value);
/// \brief Get the float value of this FP32Imm.
///
/// \return The float value of this FP32Imm.
float value() const;
};
using FP32ImmPtr = SharedPtr<FP32Imm>;
MIND_API_IMM_TRAIT(FP32Imm, float);
// === Utility functions for Value === //
/// \brief brief Create a Value object from a primitive type value.
///
/// \param[in] v The primitive type value.
///
/// \return The created Value object with the given primitive type value.
template <typename T, typename U = typename ImmTrait<T>::type::element_type>
inline ValuePtr MakeValue(T v) {
return MakeShared<U>(v);
}
/// \brief brief Create a StringImm Value object from a C string.
///
/// \param[in] s The C string.
///
/// \return The created StringImm Value object.
inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); }
/// \brief brief Create a Int64Imm Value object from a int value.
///
/// \param[in] i The int value.
///
/// \return The created Int64Imm Value object.
inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); }
/// \brief brief Create a ValueSequence object from a vector of values.
///
/// \param[in] values The vector of values.
///
/// \return The created ValueSequence object.
inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); }
/// \brief Create a ValueSequence object from a vector of primitive type values.
///
/// \param[in] values The vector of primitive values.
///
/// \return The created ValueSequence object.
template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
inline ValuePtr MakeValue(const T &values) {
std::vector<ValuePtr> value_vector;
value_vector.reserve(values.size());
for (auto &value : values) {
value_vector.emplace_back(MakeValue(value));
}
return MakeShared<ValueTuple>(value_vector);
}
/// \brief brief Get primitive type value from a Value object.
///
/// \param[in] value The pointer to the Value object.
///
/// \return The primitive type value of the Value object.
template <typename T, typename U = typename ImmTrait<T>::type>
inline T GetValue(const ValuePtr &value) {
if (value == nullptr) {
return T();
}
U imm = value->cast<U>();
if (imm == nullptr) {
return T();
}
return imm->value();
}
/// \brief brief Get primitive element values from a ValueSequeue object.
///
/// \param[in] value The pointer to the ValueSequeue object.
///
/// \return The primitive type values as a vector.
template <typename T, typename S = typename std::decay_t<T>,
typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
inline std::vector<U> GetValue(const ValuePtr &value) {
if (value == nullptr) {
return {};
}
auto seq = value->cast<ValueSequencePtr>();
if (seq == nullptr) {
return {};
}
auto elements = seq->value();
std::vector<U> result;
result.reserve(elements.size());
for (auto &e : elements) {
result.emplace_back(GetValue<U>(e));
}
return result;
}
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/abstract.h"
#include "mindapi/src/helper.h"
#include "abstract/abstract_value.h"
#include "ir/dtype.h"
#include "ir/value.h"
namespace mindspore::api {
using TypeImpl = mindspore::Type;
using ValueImpl = mindspore::Value;
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
MIND_API_BASE_IMPL(AbstractBase, AbstractBaseImpl, Base);
AbstractBasePtr AbstractBase::Clone() const {
auto abs = ToRef<AbstractBaseImpl>(impl_).Clone();
return ToWrapper<AbstractBase>(abs);
}
TypePtr AbstractBase::type() const {
auto t = ToRef<AbstractBaseImpl>(impl_).BuildType();
return ToWrapper<Type>(t);
}
ValuePtr AbstractBase::value() const {
auto v = ToRef<AbstractBaseImpl>(impl_).BuildValue();
return ToWrapper<Value>(v);
}
void AbstractBase::set_type(const TypePtr &type) {
auto type_impl = ToImpl<TypeImpl>(type);
ToRef<AbstractBaseImpl>(impl_).set_type(type_impl);
}
void AbstractBase::set_value(const ValuePtr &value) {
auto value_impl = ToImpl<ValueImpl>(value);
ToRef<AbstractBaseImpl>(impl_).set_value(value_impl);
}
using AbstractTensorImpl = mindspore::abstract::AbstractTensor;
MIND_API_BASE_IMPL(AbstractTensor, AbstractTensorImpl, AbstractBase);
AbstractTensor::AbstractTensor(TypeId type, const ShapeVector &shape)
: AbstractBase(std::make_shared<AbstractTensorImpl>(mindspore::TypeIdToType(type), shape)) {}
AbstractBasePtr AbstractTensor::element() const {
auto abs = ToRef<AbstractTensorImpl>(impl_).element();
return ToWrapper<AbstractBase>(abs);
}
ShapePtr AbstractTensor::shape() const {
auto s = ToRef<AbstractTensorImpl>(impl_).shape();
return ToWrapper<Shape>(s);
}
using AbstractSequenceImpl = mindspore::abstract::AbstractSequeue;
MIND_API_BASE_IMPL(AbstractSequence, AbstractSequenceImpl, AbstractBase);
AbstractBasePtrList AbstractSequence::elements() const {
auto &impl_elements = ToRef<AbstractSequenceImpl>(impl_).elements();
return ToWrapperVector<AbstractBase>(impl_elements);
}
using AbstractTupleImpl = mindspore::abstract::AbstractTuple;
MIND_API_BASE_IMPL(AbstractTuple, AbstractTupleImpl, AbstractSequence);
} // namespace mindspore::api

View File

@ -0,0 +1,135 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/anf.h"
#include "mindapi/src/helper.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/primitive.h"
#include "abstract/abstract_value.h"
namespace mindspore::api {
using ValueImpl = mindspore::Value;
using AnfNodeImpl = mindspore::AnfNode;
using PrimitiveImpl = mindspore::Primitive;
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
MIND_API_BASE_IMPL(AnfNode, AnfNodeImpl, Base);
std::string AnfNode::fullname_with_scope() const { return ToRef<AnfNodeImpl>(impl_).fullname_with_scope(); }
AbstractBasePtr AnfNode::abstract() const {
const auto &abs = ToRef<AnfNodeImpl>(impl_).abstract();
return ToWrapper<AbstractBase>(abs);
}
void AnfNode::set_abstract(const AbstractBasePtr &abs) {
ToRef<AnfNodeImpl>(impl_).set_abstract(ToImpl<AbstractBaseImpl>(abs));
}
using CNodeImpl = mindspore::CNode;
MIND_API_BASE_IMPL(CNode, CNodeImpl, AnfNode);
size_t CNode::size() const { return ToRef<CNodeImpl>(impl_).size(); }
AnfNodePtr CNode::input(size_t i) const {
auto &input = ToRef<CNodeImpl>(impl_).input(i);
return ToWrapper<AnfNode>(input);
}
std::vector<AnfNodePtr> CNode::inputs() const {
auto &impl_inputs = ToRef<CNodeImpl>(impl_).inputs();
return ToWrapperVector<AnfNode>(impl_inputs);
}
void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
auto impl_inputs = ToImplVector<AnfNodeImpl>(inputs);
ToRef<CNodeImpl>(impl_).set_inputs(impl_inputs);
}
void CNode::add_input(const AnfNodePtr &input) {
auto impl_input = ToImpl<AnfNodeImpl>(input);
MS_EXCEPTION_IF_NULL(impl_input);
ToRef<CNodeImpl>(impl_).add_input(impl_input);
}
void CNode::set_fullname_with_scope(const std::string &full_name) {
ToRef<CNodeImpl>(impl_).set_fullname_with_scope(full_name);
}
void CNode::AddAttr(const std::string &name, const ValuePtr &attr) {
auto impl_attr = ToImpl<ValueImpl>(attr);
MS_EXCEPTION_IF_NULL(impl_attr);
ToRef<CNodeImpl>(impl_).AddAttr(name, impl_attr);
}
void CNode::EraseAttr(const std::string &name) { ToRef<CNodeImpl>(impl_).EraseAttr(name); }
ValuePtr CNode::GetAttr(const std::string &name) const {
auto v = ToRef<CNodeImpl>(impl_).GetAttr(name);
return ToWrapper<Value>(v);
}
using ParameterImpl = mindspore::Parameter;
MIND_API_BASE_IMPL(Parameter, ParameterImpl, AnfNode);
std::string Parameter::name() const { return ToRef<ParameterImpl>(impl_).name(); }
void Parameter::set_name(const std::string &name) { ToRef<ParameterImpl>(impl_).set_name(name); }
bool Parameter::has_default() const { return ToRef<ParameterImpl>(impl_).has_default(); }
void Parameter::set_default_param(const ValuePtr &param) {
auto v = ToImpl<ValueImpl>(param);
ToRef<ParameterImpl>(impl_).set_default_param(v);
}
ValuePtr Parameter::default_param() const {
auto v = ToRef<ParameterImpl>(impl_).default_param();
return ToWrapper<Value>(v);
}
using ValueNodeImpl = mindspore::ValueNode;
MIND_API_BASE_IMPL(ValueNode, ValueNodeImpl, AnfNode);
ValueNode::ValueNode(const ValuePtr &value) : AnfNode(std::make_shared<ValueNodeImpl>(ToImpl<ValueImpl>(value))) {}
ValuePtr ValueNode::value() const {
auto v = ToRef<ValueNodeImpl>(impl_).value();
return ToWrapper<Value>(v);
}
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
return mindspore::IsPrimitiveCNode(node_impl, prim_impl);
}
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
return mindspore::IsPrimitive(node_impl, prim_impl);
}
bool IsDataNode(const AnfNodePtr &node) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
// We assume that node with monad abstract is not a data node.
return !HasAbstractMonad(node_impl);
}
} // namespace mindspore::api

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/base/base.h"
#include "base/base.h"
namespace mindspore::api {
Base::Base(const std::shared_ptr<mindspore::Base> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl_); }
uint32_t Base::ClassId() { return mindspore::Base::kTypeId; }
bool Base::IsFromClassId(uint32_t class_id) const { return impl_->IsFromTypeId(class_id); }
std::string Base::ToString() const { return impl_->ToString(); }
} // namespace mindspore::api

View File

@ -0,0 +1,170 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "mindapi/ir/func_graph.h"
#include "mindapi/src/helper.h"
#define USE_DEPRECATED_API
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "ir/primitive.h"
#include "ir/graph_utils.h"
namespace mindspore::api {
using ValueImpl = mindspore::Value;
using AnfNodeImpl = mindspore::AnfNode;
using CNodeImpl = mindspore::CNode;
using PrimitiveImpl = mindspore::Primitive;
using ParameterImpl = mindspore::Parameter;
using FuncGraphImpl = mindspore::FuncGraph;
using FuncGraphManagerImpl = mindspore::FuncGraphManager;
MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value);
std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs();
return ToWrapperVector<AnfNode>(inputs);
}
std::vector<AnfNodePtr> FuncGraph::parameters() const {
auto &params = ToRef<FuncGraphImpl>(impl_).parameters();
return ToWrapperVector<AnfNode>(params);
}
void FuncGraph::add_parameter(const ParameterPtr &p) {
auto param_impl = ToImpl<ParameterImpl>(p);
ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl);
}
ParameterPtr FuncGraph::add_parameter() {
auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter();
return ToWrapper<Parameter>(param_impl);
}
AnfNodePtr FuncGraph::output() const {
auto output = ToRef<FuncGraphImpl>(impl_).output();
return ToWrapper<AnfNode>(output);
}
CNodePtr FuncGraph::get_return() const {
auto ret = ToRef<FuncGraphImpl>(impl_).get_return();
return ToWrapper<CNode>(ret);
}
void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
auto output = ToImpl<AnfNodeImpl>(value);
ToRef<FuncGraphImpl>(impl_).set_output(output);
}
void FuncGraph::set_return(const CNodePtr &cnode) {
auto cnode_impl = ToImpl<CNodeImpl>(cnode);
ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl);
}
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs);
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl));
return ToWrapper<CNode>(cnode_impl);
}
CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
auto prim_impl = ToImpl<PrimitiveImpl>(primitive);
auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs);
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl);
return ToWrapper<CNode>(cnode_impl);
}
std::vector<AnfNodePtr> FuncGraph::nodes() const {
auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes();
return ToWrapperVector<AnfNode>(nodes);
}
bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); }
ValuePtr FuncGraph::get_attr(const std::string &key) const {
auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key);
return ToWrapper<Value>(v);
}
void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) {
auto value_impl = ToImpl<ValueImpl>(value);
ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl);
}
FuncGraphManagerPtr FuncGraph::manager() const {
auto manager = ToRef<FuncGraphImpl>(impl_).manager();
if (manager == nullptr) {
return nullptr;
}
return MakeShared<FuncGraphManager>(manager);
}
FuncGraphPtr FuncGraph::Create() {
auto fg = std::make_shared<FuncGraphImpl>();
return ToWrapper<FuncGraph>(fg);
}
std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
if (node_impl == nullptr) {
return {};
}
auto sorted = mindspore::TopoSort(node_impl);
return ToWrapperVector<AnfNode>(sorted);
}
// FuncGraphManager is not derived from Base, we implement it directly.
FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) {
MS_EXCEPTION_IF_NULL(impl_);
}
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node));
}
void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value));
}
void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value));
}
std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const {
auto &node_users = impl_->node_users();
auto iter = node_users.find(ToImpl<AnfNodeImpl>(node));
if (iter == node_users.end()) {
return {};
}
auto &users_impl = iter->second;
std::vector<std::pair<AnfNodePtr, int>> users;
users.reserve(users_impl.size());
std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users),
[](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); });
return users;
}
FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) {
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
auto mgr_impl = mindspore::Manage(fg_impl, manage);
if (mgr_impl == nullptr) {
return nullptr;
}
return MakeShared<FuncGraphManager>(mgr_impl);
}
} // namespace mindspore::api

View File

@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
#define MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
#include <memory>
#include <vector>
#include <type_traits>
#include "mindapi/base/base.h"
namespace mindspore::api {
template <typename T, typename U>
T &ToRef(const std::shared_ptr<U> &ptr) {
return static_cast<T &>(*ptr);
}
template <typename T, typename U, typename = typename std::enable_if_t<std::is_base_of_v<mindspore::Base, T>>,
typename = typename std::enable_if_t<std::is_base_of_v<Base, U>>>
std::shared_ptr<T> ToImpl(const SharedPtr<U> &wrapper) {
if (wrapper == nullptr || wrapper->impl() == nullptr) {
return nullptr;
}
return std::dynamic_pointer_cast<T>(wrapper->impl());
}
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>>>
SharedPtr<T> ToWrapper(const std::shared_ptr<mindspore::Base> &impl) {
if (impl == nullptr) {
return nullptr;
}
return MakeShared<T>(impl);
}
template <typename T, typename U>
std::vector<std::shared_ptr<T>> ToImplVector(const U &wrapper_vector) {
std::vector<std::shared_ptr<T>> impl_vector;
impl_vector.reserve(wrapper_vector.size());
for (auto &wrapper : wrapper_vector) {
impl_vector.emplace_back(ToImpl<T>(wrapper));
}
return impl_vector;
}
template <typename T, typename U>
std::vector<SharedPtr<T>> ToWrapperVector(const U &impl_vector) {
std::vector<SharedPtr<T>> wrapper_vector;
wrapper_vector.reserve(impl_vector.size());
for (auto &impl : impl_vector) {
wrapper_vector.emplace_back(ToWrapper<T>(impl));
}
return wrapper_vector;
}
#define MIND_API_BASE_IMPL(current_class, impl_class, base_class) \
current_class::current_class(const std::shared_ptr<mindspore::Base> &impl) : base_class(impl) { \
if (!impl_->isa<impl_class>()) { \
MS_LOG(EXCEPTION) << "Wrong impl " << impl_->type_name() << " for " << #current_class; \
} \
} \
uint32_t current_class::ClassId() { return impl_class::kTypeId; }
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_

View File

@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define MIND_LOG_NO_MS_LOG
#include "mindapi/base/logging.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
static MsLogLevel ToMsLogLevel(LogLevel level) {
switch (level) {
case LogLevel::DEBUG:
return MsLogLevel::DEBUG;
case LogLevel::INFO:
return MsLogLevel::INFO;
case LogLevel::WARNING:
return MsLogLevel::WARNING;
case LogLevel::ERROR:
return MsLogLevel::ERROR;
case LogLevel::EXCEPTION:
return MsLogLevel::EXCEPTION;
default:
return MsLogLevel::EXCEPTION;
}
}
class LogWriterImpl {
public:
LogWriterImpl(LogLevel level, const char *file, int line, const char *func)
: writer_(LocationInfo(file, line, func), ToMsLogLevel(level), SubModuleId::SM_API) {}
~LogWriterImpl() = default;
void Write(const LogStream &stream) const noexcept {
mindspore::LogStream log_stream;
log_stream << stream.stream_.rdbuf();
writer_ < log_stream;
}
void WriteAndThrow(const LogStream &stream) const __attribute__((noreturn)) {
mindspore::LogStream log_stream;
log_stream << stream.stream_.rdbuf();
writer_ ^ log_stream;
}
private:
mindspore::LogWriter writer_;
};
LogWriter::LogWriter(LogLevel level, const char *file, int line, const char *func)
: impl_(std::make_unique<LogWriterImpl>(level, file, line, func)) {}
LogWriter::~LogWriter() = default;
void LogWriter::operator<(const LogStream &stream) const noexcept { impl_->Write(stream); }
void LogWriter::operator^(const LogStream &stream) const { impl_->WriteAndThrow(stream); }
bool LogWriter::IsEnabled(LogLevel level) {
auto log_level = ToMsLogLevel(level);
return IS_OUTPUT_ON(log_level);
}
} // namespace mindspore::api

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/primitive.h"
#include "mindapi/src/helper.h"
#include "ir/primitive.h"
#include "ir/value.h"
namespace mindspore::api {
using ValueImpl = mindspore::Value;
using PrimitiveImpl = mindspore::Primitive;
MIND_API_BASE_IMPL(Primitive, PrimitiveImpl, Value);
Primitive::Primitive(const std::string &name) : Value(std::make_shared<PrimitiveImpl>(name)) {}
const std::string &Primitive::name() const { return ToRef<PrimitiveImpl>(impl_).name(); }
Primitive &Primitive::AddAttr(const std::string &name, const ValuePtr &attr) {
auto value = ToImpl<ValueImpl>(attr);
ToRef<PrimitiveImpl>(impl_).set_attr(name, value);
return *this;
}
Primitive &Primitive::SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
auto value = ToImpl<ValueImpl>(attr.second);
ToRef<PrimitiveImpl>(impl_).set_attr(attr.first, value);
}
return *this;
}
void Primitive::EraseAttr(const std::string &name) { ToRef<PrimitiveImpl>(impl_).EraseAttr(name); }
ValuePtr Primitive::GetAttr(const std::string &name) const {
auto v = ToRef<PrimitiveImpl>(impl_).GetAttr(name);
return ToWrapper<Value>(v);
}
bool Primitive::HasAttr(const std::string &name) const { return ToRef<PrimitiveImpl>(impl_).HasAttr(name); }
std::unordered_map<std::string, ValuePtr> Primitive::attrs() const {
std::unordered_map<std::string, ValuePtr> attr_map;
auto &impl_attrs = ToRef<PrimitiveImpl>(impl_).attrs();
attr_map.reserve(impl_attrs.size());
for (auto &attr : impl_attrs) {
auto value = ToWrapper<Value>(attr.second);
attr_map.emplace(attr.first, value);
}
return attr_map;
}
} // namespace mindspore::api

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/shape.h"
#include "mindapi/src/helper.h"
#include "abstract/dshape.h"
namespace mindspore::api {
using ShapeImpl = mindspore::abstract::Shape;
MIND_API_BASE_IMPL(Shape, ShapeImpl, Base);
const ShapeVector &Shape::shape() const { return ToRef<ShapeImpl>(impl_).shape(); }
} // namespace mindspore::api

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "mindapi/ir/tensor.h"
#include "mindapi/src/helper.h"
#include "ir/tensor.h"
namespace mindspore::api {
using TensorImpl = mindspore::tensor::Tensor;
MIND_API_BASE_IMPL(Tensor, TensorImpl, Value);
Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : Value(std::make_shared<TensorImpl>(data_type, shape)) {}
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
: Value(std::make_shared<TensorImpl>(data_type, shape, data, data_len)) {}
const ShapeVector &Tensor::shape() const { return ToRef<TensorImpl>(impl_).shape(); }
void Tensor::set_shape(const ShapeVector &shape) { (void)ToRef<TensorImpl>(impl_).set_shape(shape); }
TypeId Tensor::data_type() const { return ToRef<TensorImpl>(impl_).data_type(); }
void Tensor::set_data_type(const TypeId data_type) { (void)ToRef<TensorImpl>(impl_).set_data_type(data_type); }
const void *Tensor::data() const { return ToRef<TensorImpl>(impl_).data_c(); }
void *Tensor::data() { return ToRef<TensorImpl>(impl_).data_c(); }
int Tensor::DataSize() const { return ToRef<TensorImpl>(impl_).DataSize(); }
size_t Tensor::Size() const { return ToRef<TensorImpl>(impl_).Size(); }
} // namespace mindspore::api

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/type.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "ir/dtype/type.h"
#include "ir/dtype.h"
#include "abstract/utils.h"
namespace mindspore::api {
using TypeImpl = mindspore::Type;
MIND_API_BASE_IMPL(Type, TypeImpl, Value);
TypeId Type::type_id() const { return ToRef<TypeImpl>(impl_).type_id(); }
TypeId Type::number_type() const { return ToRef<TypeImpl>(impl_).number_type(); }
TypePtr Type::GetType(TypeId id) {
auto type_impl = mindspore::TypeIdToType(id);
return ToWrapper<Type>(type_impl);
}
size_t Type::GetSize(TypeId id) { return mindspore::abstract::TypeIdSize(id); }
} // namespace mindspore::api

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/utils.h"
#include "mindapi/src/helper.h"
#define USE_DEPRECATED_API
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/check_convert_utils.h"
namespace mindspore::api::utils {
using ValueImpl = mindspore::Value;
using FuncGraphImpl = mindspore::FuncGraph;
MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
Cloner cloner({fg_impl}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
auto cloned_fg = cloner[fg_impl];
return ToWrapper<api::FuncGraph>(cloned_fg);
}
int64_t GetPadMode(const api::ValuePtr &value, bool is_upper) {
int64_t result;
auto value_impl = ToImpl<ValueImpl>(value);
CheckAndConvertUtils::GetPadModEnumValue(value_impl, &result, is_upper);
return result;
}
} // namespace mindspore::api::utils

View File

@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindapi/ir/value.h"
#include "mindapi/ir/type.h"
#include "mindapi/ir/abstract.h"
#include "mindapi/src/helper.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "ir/value.h"
#include "ir/scalar.h"
namespace mindspore::api {
using ValueImpl = mindspore::Value;
using ValueSequenceImpl = mindspore::ValueSequeue; // 'Sequeue' is typo.
using ValueTupleImpl = mindspore::ValueTuple;
using StringImmImpl = mindspore::StringImm;
using ScalarImpl = mindspore::Scalar;
using BoolImmImpl = mindspore::BoolImm;
using IntegerImmImpl = mindspore::IntergerImm; // 'Interger' is typo.
using Int64ImmImpl = mindspore::Int64Imm;
using FloatImmImpl = mindspore::FloatImm;
using FP32ImmImpl = mindspore::FP32Imm;
MIND_API_BASE_IMPL(Value, ValueImpl, Base);
TypePtr Value::type() const {
auto t = ToRef<ValueImpl>(impl_).type();
return ToWrapper<Type>(t);
}
AbstractBasePtr Value::ToAbstract() const {
auto abs = ToRef<ValueImpl>(impl_).ToAbstract();
return ToWrapper<AbstractBase>(abs);
}
MIND_API_BASE_IMPL(ValueSequence, ValueSequenceImpl, Value);
std::size_t ValueSequence::size() const { return ToRef<ValueSequenceImpl>(impl_).size(); }
std::vector<ValuePtr> ValueSequence::value() const {
auto &elements = ToRef<ValueSequenceImpl>(impl_).value();
return ToWrapperVector<Value>(elements);
}
MIND_API_BASE_IMPL(ValueTuple, ValueTupleImpl, ValueSequence);
ValueTuple::ValueTuple(const std::vector<ValuePtr> &elements)
: ValueSequence(std::make_shared<ValueTupleImpl>(ToImplVector<ValueImpl>(elements))) {}
MIND_API_BASE_IMPL(StringImm, StringImmImpl, Value);
StringImm::StringImm(const std::string &str) : Value(std::make_shared<StringImmImpl>(str)) {}
const std::string &StringImm::value() const { return ToRef<StringImmImpl>(impl_).value(); }
MIND_API_BASE_IMPL(Scalar, ScalarImpl, Value);
MIND_API_BASE_IMPL(BoolImm, BoolImmImpl, Scalar);
BoolImm::BoolImm(bool b) : Scalar(std::make_shared<BoolImmImpl>(b)) {}
bool BoolImm::value() const { return ToRef<BoolImmImpl>(impl_).value(); }
MIND_API_BASE_IMPL(IntegerImm, IntegerImmImpl, Scalar);
MIND_API_BASE_IMPL(Int64Imm, Int64ImmImpl, IntegerImm);
Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(value)) {}
int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); }
MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar);
MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm);
FP32Imm::FP32Imm(float value) : FloatImm(std::make_shared<FP32ImmImpl>(value)) {}
float FP32Imm::value() const { return ToRef<FP32ImmImpl>(impl_).value(); }
} // namespace mindspore::api

View File

@ -141,6 +141,7 @@ enum SubModuleId : int {
SM_HCCL_ADPT, // Hccl Adapter
SM_RUNTIME_FRAMEWORK, // Runtime framework
SM_GE, // GraphEngine
SM_API, // MindAPI
NUM_SUBMODUES // number of submodules
};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@
#ifndef MINDSPORE_SHAPE_UTILS_INFO_H_
#define MINDSPORE_SHAPE_UTILS_INFO_H_
#include <vector>
using ShapeVector = std::vector<int64_t>;
#include "mindapi/base/shape_vector.h"
#endif // MINDSPORE_SHAPE_UTILS_INFO_H_

View File

@ -16,6 +16,8 @@ set(API_IR_HEADER
${CORE_DIR}/api/ir/func_graph.h
${CORE_DIR}/api/ir/func_graph_manager.h
)
file(GLOB MINDAPI_BASE_HEADER ${CORE_DIR}/mindapi/base/*.h)
file(GLOB MINDAPI_IR_HEADER ${CORE_DIR}/mindapi/ir/*.h)
set(BASE_HEADER
${CORE_DIR}/base/base.h
${CORE_DIR}/base/base_ref.h

View File

@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include

View File

@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include

View File

@ -19,6 +19,7 @@ INF_OBJ:=$(INF_SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include

View File

@ -289,6 +289,9 @@ if(APPLE)
set(MINDSPORE_LITE_PUB_HDRS_IR_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_id.h
)
set(MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/../../core/mindapi/base/type_id.h
)
add_library(mindspore-lite_static STATIC
${LITE_SRC}
${MINDSPORE_LITE_PUB_HDRS}
@ -423,6 +426,9 @@ if(DEFINED ARCHS)
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_IR_HDRS})
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/ir/dtype/)
ENDFOREACH()
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS})
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/mindapi/base/)
ENDFOREACH()
target_link_libraries(mindspore-lite_static)
endif()

View File

@ -71,6 +71,7 @@ if(ENABLE_MINDDATA)
./fl/*.cc
./cxx_api/*.cc
./tbe/*.cc
./mindapi/*.cc
)
if(NOT ENABLE_SECURITY)
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}

View File

@ -0,0 +1,394 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <memory>
#include <sstream>
#include <unordered_map>
#include "common/common_test.h"
#include "mindapi/base/logging.h"
#include "mindapi/ir/func_graph.h"
#include "mindapi/ir/tensor.h"
#include "mindapi/ir/utils.h"
namespace mindspore::api {
class TestMindApi : public UT::Common {
public:
TestMindApi() = default;
};
/// Feature: MindAPI
/// Description: test basic 'is()' 'cast()'
/// Expectation: is/cast works correctly.
TEST_F(TestMindApi, test_base_isa_cast) {
auto value_node = MakeShared<ValueNode>(MakeValue(0));
auto base = MakeShared<Base>(value_node->impl());
ASSERT_TRUE(base->isa<Base>());
ASSERT_TRUE(base->isa<AnfNode>());
ASSERT_TRUE(base->isa<ValueNode>());
ASSERT_FALSE(base->isa<AbstractBase>());
auto anf_node = base->cast<AnfNodePtr>();
ASSERT_TRUE(anf_node != nullptr);
ASSERT_TRUE(anf_node->impl() == value_node->impl());
ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
}
/// Feature: MindAPI
/// Description: test graph construction.
/// Expectation: graph is constructed as expected.
TEST_F(TestMindApi, test_graph_construction) {
// fg(x) { return myprim(x, 1); }
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, x, value_node});
fg->set_output(cnode);
// Now we check the graph.
ASSERT_EQ(fg->parameters().size(), 1);
ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");
auto ret_node = fg->get_return();
ASSERT_TRUE(ret_node != nullptr);
auto output_node = fg->output();
ASSERT_TRUE(output_node != nullptr);
ASSERT_TRUE(output_node->isa<CNode>());
auto output_cnode = output_node->cast<CNodePtr>();
ASSERT_EQ(output_cnode->inputs().size(), 3);
ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());
ASSERT_EQ(output_cnode->impl(), cnode->impl());
}
/// Feature: MindAPI
/// Description: test value related functions.
/// Expectation: value related functions work as expected.
TEST_F(TestMindApi, test_values) {
int64_t one = 1;
auto s = MakeValue("hello");
auto i = MakeValue(one);
auto i2 = MakeValue(2);
auto b = MakeValue(true);
auto f = MakeValue(3.14f);
auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));
ASSERT_TRUE(s->isa<StringImm>());
ASSERT_TRUE(i->isa<Int64Imm>());
ASSERT_TRUE(i2->isa<Int64Imm>());
ASSERT_TRUE(b->isa<BoolImm>());
ASSERT_TRUE(f->isa<FP32Imm>());
ASSERT_TRUE(seq->isa<ValueSequence>());
ASSERT_TRUE(seq_str->isa<ValueSequence>());
ASSERT_EQ(GetValue<std::string>(s), "hello");
ASSERT_EQ(GetValue<int64_t>(i), one);
ASSERT_EQ(GetValue<int64_t>(i2), 2);
ASSERT_TRUE(GetValue<bool>(b));
ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);
ASSERT_EQ(GetValue<std::string>(i), "");
ASSERT_EQ(GetValue<int64_t>(s), 0);
ASSERT_FALSE(GetValue<bool>(s));
ASSERT_EQ(GetValue<float>(s), 0.0f);
auto seq_ptr = seq->cast<ValueSequencePtr>();
ASSERT_TRUE(seq_ptr != nullptr);
ASSERT_EQ(seq_ptr->size(), 3);
ASSERT_EQ(seq_ptr->value().size(), 3);
ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);
auto seq_values = GetValue<std::vector<int64_t>>(seq);
ASSERT_EQ(seq_values.size(), 3);
ASSERT_EQ(seq_values[0], 3);
ASSERT_EQ(seq_values[1], 4);
ASSERT_EQ(seq_values[2], 5);
auto str_values = GetValue<std::vector<std::string>>(seq_str);
ASSERT_EQ(str_values.size(), 4);
ASSERT_EQ(str_values[0], "this");
ASSERT_EQ(str_values[1], "is");
ASSERT_EQ(str_values[2], "mindspore");
ASSERT_EQ(str_values[3], "api");
}
/// Feature: MindAPI
/// Description: test graph manager functions.
/// Expectation: graph manager functions work as expected.
TEST_F(TestMindApi, test_func_graph_manager) {
// fg(x, y) { return myprim(add(x, y), 1); }
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto y = fg->add_parameter();
y->set_name("y");
auto add = MakeShared<Primitive>("add");
auto add_node = MakeShared<ValueNode>(add);
auto add_cnode = fg->NewCNode({add_node, x, y});
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
fg->set_output(cnode);
auto mgr = FuncGraphManager::Manage(fg);
ASSERT_TRUE(mgr != nullptr);
ASSERT_TRUE(fg->manager() != nullptr);
ASSERT_EQ(fg->manager()->impl(), mgr->impl());
ASSERT_EQ(fg->manager(), mgr);
ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
mgr->Replace(add_cnode, x);
ASSERT_EQ(cnode->input(1)->impl(), x->impl());
mgr->SetEdge(cnode, 1, y);
ASSERT_EQ(cnode->input(1)->impl(), y->impl());
mgr->AddEdge(cnode, x);
ASSERT_EQ(cnode->size(), 4);
ASSERT_EQ(cnode->input(3)->impl(), x->impl());
auto users = mgr->GetUsers(value_node);
ASSERT_EQ(users.size(), 1);
ASSERT_EQ(users[0].first, cnode);
ASSERT_EQ(users[0].second, 2);
}
/// Feature: MindAPI
/// Description: test value node utils.
/// Expectation: value node utils work as expected.
TEST_F(TestMindApi, test_value_node_utils) {
auto fg = FuncGraph::Create();
auto fg_node = MakeShared<ValueNode>(fg);
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto one = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({fg_node, prim_node, one});
ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);
auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
ASSERT_TRUE(fg1 != nullptr);
ASSERT_TRUE(fg1->isa<FuncGraph>());
auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
ASSERT_TRUE(prim1 != nullptr);
ASSERT_TRUE(prim1->isa<Primitive>());
auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
ASSERT_TRUE(imm != nullptr);
ASSERT_TRUE(imm->isa<Int64Imm>());
ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);
auto value = GetValueNode(cnode->input(2));
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(GetValue<int64_t>(value), 1);
ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);
// Test NewValueNode.
auto int_node = NewValueNode(1);
auto bool_node = NewValueNode(true);
auto float_node = NewValueNode(1.23f);
auto str_node = NewValueNode("hello");
ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
ASSERT_TRUE(str_node->value()->isa<StringImm>());
ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
}
/// Feature: MindAPI
/// Description: test SharedPtr.
/// Expectation: SharedPtr work as expected.
TEST_F(TestMindApi, test_object_ptr) {
auto fg = FuncGraph::Create();
auto fg_node = MakeShared<ValueNode>(fg);
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto one = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({fg_node, prim_node, one});
ASSERT_TRUE(fg != nullptr);
ASSERT_FALSE(!fg);
ASSERT_TRUE(fg ? true : false);
ASSERT_TRUE((*cnode).input(0) == fg_node);
ASSERT_TRUE(cnode->input(0) == fg_node);
ASSERT_TRUE(cnode.get()->input(0) == fg_node);
ASSERT_EQ(cnode->input(0), fg_node);
ASSERT_EQ(cnode->input(1), prim_node);
ASSERT_EQ(cnode->input(2), one);
ASSERT_TRUE(cnode->input(0) != fg);
AnfNodePtr p = fg_node;
ASSERT_TRUE(p == fg_node);
ASSERT_TRUE(p->isa<ValueNode>());
ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);
p = cnode;
ASSERT_TRUE(p == cnode);
ASSERT_TRUE(p->isa<CNode>());
ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
ASSERT_TRUE(p.get() == cnode.get());
ASSERT_TRUE(p != nullptr);
ASSERT_FALSE(p == nullptr);
ASSERT_TRUE(p > nullptr);
ASSERT_FALSE(p < nullptr);
ASSERT_TRUE(p >= nullptr);
ASSERT_FALSE(p <= nullptr);
ASSERT_TRUE(nullptr != p);
ASSERT_FALSE(nullptr == p);
ASSERT_TRUE(nullptr < p);
ASSERT_FALSE(nullptr > p);
ASSERT_TRUE(nullptr <= p);
ASSERT_FALSE(nullptr >= p);
AnfNodePtr q = fg_node;
ASSERT_TRUE(p != q);
ASSERT_TRUE(p > q);
if (p.get()->impl() > q.get()->impl()) {
ASSERT_TRUE(p > q);
ASSERT_TRUE(p >= q);
ASSERT_TRUE(q < p);
ASSERT_TRUE(q <= p);
} else {
ASSERT_TRUE(p < q);
ASSERT_TRUE(p <= q);
ASSERT_TRUE(q > p);
ASSERT_TRUE(q >= p);
}
std::stringstream ss1;
std::stringstream ss2;
ss1 << p;
ss2 << cnode.get()->impl().get();
ASSERT_EQ(ss1.str(), ss2.str());
std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
mymap.emplace(p, q);
mymap.emplace(q, p);
ASSERT_TRUE(mymap.find(p) != mymap.end());
ASSERT_TRUE(mymap.find(q) != mymap.end());
ASSERT_TRUE(mymap[p] == q);
ASSERT_TRUE(mymap[q] == p);
}
/// Feature: MindAPI
/// Description: test Tensor API.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_tensor_api) {
ShapeVector shape{1, 2, 3};
auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor->shape(), shape);
ASSERT_EQ(tensor->DataSize(), 6);
ASSERT_EQ(tensor->Size(), 24);
ShapeVector shape2{2, 3};
tensor->set_data_type(kNumberTypeInt32);
tensor->set_shape(shape2);
ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor->shape(), shape2);
}
/// Feature: MindAPI
/// Description: test utils API.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_api_utils) {
// Test utils::isa, utils::cast.
auto anf_node = NewValueNode("hello");
ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);
anf_node = nullptr;
ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);
// Test clone graph.
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto y = fg->add_parameter();
y->set_name("y");
auto add = MakeShared<Primitive>("add");
auto add_node = MakeShared<ValueNode>(add);
auto add_cnode = fg->NewCNode({add_node, x, y});
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
fg->set_output(cnode);
auto cloned_fg = utils::CloneGraph(fg);
ASSERT_TRUE(cloned_fg != nullptr);
ASSERT_EQ(cloned_fg->parameters().size(), 2);
auto new_output = cloned_fg->output();
ASSERT_TRUE(new_output != nullptr);
ASSERT_TRUE(new_output->isa<CNode>());
ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
ASSERT_TRUE(new_output != cnode);
ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);
// Test get pad mode.
auto pm_lower = MakeValue("pad");
auto pm_upper = MakeValue("PAD");
ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
}
/// Feature: MindAPI
/// Description: test logging API.
/// Expectation: logging work as expected.
TEST_F(TestMindApi, test_api_logging) {
MS_LOG(DEBUG) << "hello debug";
MS_LOG(INFO) << "hello info";
MS_LOG(WARNING) << "hello warning";
MS_LOG(ERROR) << "hello error";
try {
MS_LOG(EXCEPTION) << "hello exception";
ASSERT_TRUE(false);
} catch (...) {
}
ASSERT_TRUE(true);
}
} // namespace mindspore::api