forked from mindspore-Ecosystem/mindspore
initial version
Signed-off-by: leonwanghui <leon.wanghui@huawei.com>
This commit is contained in:
commit
930a1fb0a8
|
@ -0,0 +1,152 @@
|
|||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: Google
|
||||
AccessModifierOffset: -1
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlines: Left
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: true
|
||||
AllowShortLoopsOnASingleLine: true
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: true
|
||||
AlwaysBreakTemplateDeclarations: Yes
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterClass: false
|
||||
AfterControlStatement: false
|
||||
AfterEnum: false
|
||||
AfterFunction: false
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: false
|
||||
AfterStruct: false
|
||||
AfterUnion: false
|
||||
AfterExternBlock: false
|
||||
BeforeCatch: false
|
||||
BeforeElse: false
|
||||
IndentBraces: false
|
||||
SplitEmptyFunction: true
|
||||
SplitEmptyRecord: true
|
||||
SplitEmptyNamespace: true
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Attach
|
||||
BreakBeforeInheritanceComma: false
|
||||
BreakInheritanceList: BeforeColon
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
BreakConstructorInitializers: BeforeColon
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakStringLiterals: true
|
||||
ColumnLimit: 120
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
CompactNamespaces: false
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 2
|
||||
Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: true
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
FixNamespaceComments: true
|
||||
ForEachMacros:
|
||||
# - foreach
|
||||
- Q_FOREACH
|
||||
- BOOST_FOREACH
|
||||
IncludeBlocks: Preserve
|
||||
IncludeCategories:
|
||||
- Regex: '^<ext/.*\.h>'
|
||||
Priority: 2
|
||||
- Regex: '^<.*\.h>'
|
||||
Priority: 1
|
||||
- Regex: '^<.*'
|
||||
Priority: 2
|
||||
- Regex: '.*'
|
||||
Priority: 3
|
||||
IncludeIsMainRegex: '([-_](test|unittest))?$'
|
||||
IndentCaseLabels: true
|
||||
IndentPPDirectives: None
|
||||
IndentWidth: 2
|
||||
IndentWrappedFunctionNames: false
|
||||
JavaScriptQuotes: Leave
|
||||
JavaScriptWrapImports: true
|
||||
KeepEmptyLinesAtTheStartOfBlocks: false
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBinPackProtocolList: Never
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: true
|
||||
PenaltyBreakAssignment: 2
|
||||
PenaltyBreakBeforeFirstCallParameter: 1
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyBreakTemplateDeclaration: 10
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 200
|
||||
PointerAlignment: Left
|
||||
RawStringFormats:
|
||||
- Language: Cpp
|
||||
Delimiters:
|
||||
- cc
|
||||
- CC
|
||||
- cpp
|
||||
- Cpp
|
||||
- CPP
|
||||
- 'c++'
|
||||
- 'C++'
|
||||
CanonicalDelimiter: ''
|
||||
BasedOnStyle: google
|
||||
- Language: TextProto
|
||||
Delimiters:
|
||||
- pb
|
||||
- PB
|
||||
- proto
|
||||
- PROTO
|
||||
EnclosingFunctions:
|
||||
- EqualsProto
|
||||
- EquivToProto
|
||||
- PARSE_PARTIAL_TEXT_PROTO
|
||||
- PARSE_TEST_PROTO
|
||||
- PARSE_TEXT_PROTO
|
||||
- ParseTextOrDie
|
||||
- ParseTextProtoOrDie
|
||||
CanonicalDelimiter: ''
|
||||
BasedOnStyle: google
|
||||
ReflowComments: true
|
||||
SortIncludes: true
|
||||
SortUsingDeclarations: true
|
||||
SpaceAfterCStyleCast: false
|
||||
SpaceAfterTemplateKeyword: true
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeCpp11BracedList: false
|
||||
SpaceBeforeCtorInitializerColon: true
|
||||
SpaceBeforeInheritanceColon: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceBeforeRangeBasedForLoopColon: true
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 2
|
||||
SpacesInAngles: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Auto
|
||||
StatementMacros:
|
||||
- Q_UNUSED
|
||||
- QT_REQUIRE_VERSION
|
||||
TabWidth: 2
|
||||
UseTab: Never
|
||||
SortIncludes: false
|
||||
...
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# MindSpore
|
||||
build/
|
||||
mindspore/lib
|
||||
output
|
||||
*.ir
|
||||
|
||||
# Cmake files
|
||||
CMakeFiles/
|
||||
cmake_install.cmake
|
||||
CMakeCache.txt
|
||||
Makefile
|
||||
cmake-build-debug
|
||||
|
||||
# Dynamic libraries
|
||||
*.so
|
||||
*.so.*
|
||||
*.dylib
|
||||
|
||||
# Static libraries
|
||||
*.la
|
||||
*.lai
|
||||
*.a
|
||||
*.lib
|
||||
|
||||
# Protocol buffers
|
||||
*_pb2.py
|
||||
*.pb.h
|
||||
*.pb.cc
|
||||
|
||||
# Object files
|
||||
*.o
|
||||
|
||||
# Editor
|
||||
.vscode
|
||||
.idea/
|
||||
|
||||
# Cquery
|
||||
.cquery_cached_index/
|
||||
compile_commands.json
|
||||
|
||||
# Ctags and cscope
|
||||
tags
|
||||
TAGS
|
||||
CTAGS
|
||||
GTAGS
|
||||
GRTAGS
|
||||
GSYMS
|
||||
GPATH
|
||||
cscope.*
|
||||
|
||||
# Python files
|
||||
*__pycache__*
|
||||
.pytest_cache
|
||||
|
||||
# Mac files
|
||||
*.DS_Store
|
||||
|
||||
# Test results
|
||||
test_temp_summary_event_file/
|
||||
*.dot
|
||||
*.dat
|
||||
*.svg
|
||||
*.perf
|
||||
*.info
|
||||
*.ckpt
|
||||
*.shp
|
||||
*.pkl
|
||||
.clangd
|
||||
mindspore/version.py
|
||||
mindspore/default_config.py
|
||||
mindspore/.commit_id
|
||||
onnx.proto
|
||||
mindspore/ccsrc/onnx.proto
|
|
@ -0,0 +1,15 @@
|
|||
[submodule "third_party/flatbuffers"]
|
||||
path = third_party/flatbuffers
|
||||
url = https://github.com/google/flatbuffers.git
|
||||
[submodule "third_party/googletest"]
|
||||
path = third_party/googletest
|
||||
url = https://github.com/google/googletest.git
|
||||
[submodule "third_party/incubator-tvm"]
|
||||
path = third_party/incubator-tvm
|
||||
url = https://github.com/apache/incubator-tvm.git
|
||||
[submodule "third_party/protobuf"]
|
||||
path = third_party/protobuf
|
||||
url = https://github.com/protocolbuffers/protobuf.git
|
||||
[submodule "graphengine"]
|
||||
path = graphengine
|
||||
url = https://gitee.com/mindspore/graphengine.git
|
|
@ -0,0 +1,55 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project (MindSpore)
|
||||
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)=' -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC")
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(PYBIND11_CPP_STANDARD -std=c++17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPTION_CXX_FLAGS}")
|
||||
|
||||
find_package(Threads)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/mind_expression.cmake)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/flatbuffers)
|
||||
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_utils.cmake)
|
||||
find_package(Python3 COMPONENTS Interpreter Development)
|
||||
if(Python3_FOUND)
|
||||
set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}")
|
||||
set(PYTHON_LIBRARIES "${Python3_LIBRARIES}")
|
||||
else()
|
||||
find_python_package(py_inc py_lib)
|
||||
set(PYTHON_INCLUDE_DIRS "${py_inc}")
|
||||
set(PYTHON_LIBRARIES "${py_lib}")
|
||||
endif()
|
||||
message("PYTHON_INCLUDE_DIRS = ${PYTHON_INCLUDE_DIRS}")
|
||||
message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
|
||||
set(MS_CCSRC_PATH ${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
|
||||
set(MS_CCSRC_BUILD_PATH ${BUILD_PATH}/mindspore/mindspore/ccsrc)
|
||||
|
||||
if (ENABLE_GE)
|
||||
link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib)
|
||||
else()
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_graphengine.cmake)
|
||||
endif()
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
add_subdirectory(mindspore/ccsrc)
|
||||
if (ENABLE_TESTCASES)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
|
@ -0,0 +1,115 @@
|
|||
# MindSpore contributing guidelines
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [MindSpore contributing guidelines](#mindspore-contributing-guidelines)
|
||||
- [Contributor License Agreement](#contributor-license-agreement)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Contribution workflow](#contribution-workflow)
|
||||
- [Code style](#code-style)
|
||||
- [Fork-Pull development model](#fork-pull-development-model)
|
||||
- [Report issues](#report-issues)
|
||||
- [Propose PRs](#propose-prs)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## Contributor License Agreement
|
||||
|
||||
It's required to sign CLA before your first code submission to MindSpore community.
|
||||
|
||||
For individual contributor, please refer to [ICLA online document](https://www.mindspore.cn/icla) for the detailed information.
|
||||
|
||||
## Getting Started
|
||||
|
||||
- Fork the repository on [Github](https://github.com/mindspore-ai/mindspore) or [Gitee](https://gitee.com/mindspore/mindspore).
|
||||
- Read the [README.md](README.md) and [install page](https://www.mindspore.cn/install/en) for project information and build instructions.
|
||||
|
||||
## Contribution Workflow
|
||||
|
||||
### Code style
|
||||
|
||||
Please follow this style to make MindSpore easy to review, maintain and develop.
|
||||
|
||||
* Coding guidelines
|
||||
|
||||
The *Python* coding style suggested by [Python PEP 8 Coding Style](https://pep8.org/) and *C++* coding style suggested by [Google C++ Coding Guidelines](http://google.github.io/styleguide/cppguide.html) are used in MindSpore community.
|
||||
|
||||
* Unittest guidelines
|
||||
|
||||
The *Python* unittest style suggested by [pytest](http://www.pytest.org/en/latest/) and *C++* unittest style suggested by [Googletest Primer](https://github.com/google/googletest/blob/master/googletest/docs/primer.md) are used in MindSpore community.
|
||||
|
||||
### Fork-Pull development model
|
||||
|
||||
* Fork MindSpore repository
|
||||
|
||||
Before submitting code to the MindSpore project, please make sure that this project has been forked to your own repository. This means there will be parallel development between the MindSpore repository and your own repository, so be careful to avoid inconsistency between them.
|
||||
|
||||
* Clone the remote repository
|
||||
|
||||
If you want to download the code to the local machine, `git` is the best way:
|
||||
```shell
|
||||
# For GitHub
|
||||
git clone https://github.com/{insert_your_forked_repo}/mindspore.git
|
||||
git remote add upstream https://github.com/mindspore-ai/mindspore.git
|
||||
# For Gitee
|
||||
git clone https://gitee.com/{insert_your_forked_repo}/mindspore.git
|
||||
git remote add upstream https://gitee.com/mindspore/mindspore.git
|
||||
```
|
||||
|
||||
* Develop code locally
|
||||
|
||||
To avoid inconsistency between multiple branches, checking out to a new branch is `SUGGESTED`:
|
||||
```shell
|
||||
git checkout -b {new_branch_name} origin/master
|
||||
```
|
||||
|
||||
Then you can change the code arbitrarily.
|
||||
|
||||
* Push the code to the remote repository
|
||||
|
||||
After updating the code, you should push the update in the formal way:
|
||||
```shell
|
||||
git add .
|
||||
git status # Check the update status
|
||||
git commit -m "Your commit title"
|
||||
git commit -s --amend # Add the concrete description of your commit
|
||||
git push origin {new_branch_name}
|
||||
```
|
||||
|
||||
* Pull a request to MindSpore repository
|
||||
|
||||
In the last step, you need to open a pull request comparing your new branch against the MindSpore `master` branch. After the pull request is created, the Jenkins CI will be automatically set up for build testing.
|
||||
|
||||
### Report issues
|
||||
|
||||
A great way to contribute to the project is to send a detailed report when you encounter an issue. We always appreciate a well-written, thorough bug report, and will thank you for it!
|
||||
|
||||
When reporting issues, refer to this format:
|
||||
|
||||
- What version of env (mindspore, os, python etc) are you using?
|
||||
- Is this a BUG REPORT or FEATURE REQUEST?
|
||||
- What happened?
|
||||
- What you expected to happen?
|
||||
- How to reproduce it? (as minimally and precisely as possible)
|
||||
- Special notes for your reviewers?
|
||||
|
||||
**Issues advisory:**
|
||||
|
||||
- **If you find an unclosed issue, which is exactly what you are going to solve,** please put some comments on that issue to tell others you would be in charge of it.
|
||||
- **If an issue is opened for a while,** it's recommended for contributors to precheck before working on solving that issue.
|
||||
- **If you resolve an issue which is reported by yourself,** it's also required to let others know before closing that issue.
|
||||
|
||||
### Propose PRs
|
||||
|
||||
* Raise your idea as an *issue* on [GitHub](https://github.com/mindspore-ai/mindspore/issues) or [Gitee](https://gitee.com/mindspore/mindspore/issues)
|
||||
* If it is a new feature that needs lots of design details, a design proposal should also be submitted.
|
||||
* After reaching consensus in the issue discussions and design proposal reviews, complete the development on the forked repo and submit a PR.
|
||||
* No PR may be merged until it receives **2+ LGTM** from approvers. Please NOTE that an approver is NOT allowed to add *LGTM* to his own PR.
|
||||
* After a PR is sufficiently discussed, it will get merged, abandoned or rejected depending on the outcome of the discussion.
|
||||
|
||||
**PRs advisory:**
|
||||
|
||||
- Any irrelevant changes should be avoided.
|
||||
- Make sure your commit history is ordered.
|
||||
- Always keep your branch up to date with the master branch.
|
||||
- For bug-fix PRs, make sure all related issues being linked.
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,2 @@
|
|||
MindSpore
|
||||
Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
@ -0,0 +1,124 @@
|
|||
![MindSpore Logo](docs/MindSpore-logo.png "MindSpore logo")
|
||||
============================================================
|
||||
|
||||
- [What is MindSpore?](#what-is-mindspore)
|
||||
- [Automatic Differentiation](#automatic-differentiation)
|
||||
- [Automatic Parallel](#automatic-parallel)
|
||||
- [Installation](#installation)
|
||||
- [Binaries](#binaries)
|
||||
- [From Source](#from-source)
|
||||
- [Quickstart](#quickstart)
|
||||
- [Docs](#docs)
|
||||
- [Community](#community)
|
||||
- [Governance](#governance)
|
||||
- [Communication](#communication)
|
||||
- [Contributing](#contributing)
|
||||
- [Release Notes](#release-notes)
|
||||
- [License](#license)
|
||||
|
||||
## What Is MindSpore
|
||||
|
||||
MindSpore is a new open source deep learning training/inference framework that
|
||||
could be used for mobile, edge and cloud scenarios. MindSpore is designed to
|
||||
provide development experience with friendly design and efficient execution for
|
||||
the data scientists and algorithmic engineers, native support for Ascend AI
|
||||
processor, and software hardware co-optimization. Meanwhile, MindSpore, as
|
||||
a global AI open source community, aims to further advance the development and
|
||||
enrichment of the AI software/hardware application ecosystem.
|
||||
|
||||
<img src="docs/MindSpore-architecture.png" alt="MindSpore Architecture" width="600"/>
|
||||
|
||||
For more details please check out our [Architecture Guide](https://www.mindspore.cn/docs/en/0.1.0-alpha/architecture.html).
|
||||
|
||||
### Automatic Differentiation
|
||||
|
||||
There are currently three automatic differentiation techniques in mainstream deep learning frameworks:
|
||||
|
||||
- **Conversion based on static compute graph**: Convert the network into a static data flow graph at compile time, then turn the chain rule into a data flow graph to implement automatic differentiation.
|
||||
- **Conversion based on dynamic compute graph**: Record the operation trajectory of the network during forward execution in an operator overloaded manner, then apply the chain rule to the dynamically generated data flow graph to implement automatic differentiation.
|
||||
- **Conversion based on source code**: This technology is evolving from the functional programming framework and performs automatic differential transformation on the intermediate expression (the expression form of the program during the compilation process) in the form of just-in-time compilation (JIT), supporting complex control flow scenarios, higher-order functions and closures.
|
||||
|
||||
TensorFlow adopted static computation graphs in the early days, whereas PyTorch used dynamic computation graphs. Static graphs can utilize static compilation technology to optimize network performance; however, building or debugging a network is very complicated. Dynamic graphs are very convenient to use, but it is difficult to achieve extreme optimization in performance.
|
||||
|
||||
But MindSpore finds another way, automatic differentiation based on source code conversion. On the one hand, it supports automatic differentiation of automatic control flow, so it is quite convenient to build models like PyTorch. On the other hand, MindSpore can perform static compilation optimization on neural networks to achieve great performance.
|
||||
|
||||
<img src="docs/Automatic-differentiation.png" alt="Automatic Differentiation" width="600"/>
|
||||
|
||||
The implementation of MindSpore automatic differentiation can be understood as the symbolic differentiation of the program itself. Because MindSpore IR is a functional intermediate expression, it has an intuitive correspondence with the composite function in basic algebra. The derivation formula of the composite function composed of arbitrary basic functions can be derived. Each primitive operation in MindSpore IR can correspond to the basic functions in basic algebra, which can build more complex flow control.
|
||||
|
||||
### Automatic Parallel
|
||||
|
||||
The goal of MindSpore automatic parallel is to build a training method that combines data parallelism, model parallelism, and hybrid parallelism. It can automatically select a least cost model splitting strategy to achieve automatic distributed parallel training.
|
||||
|
||||
<img src="docs/Automatic-parallel.png" alt="Automatic Parallel" width="600"/>
|
||||
|
||||
At present, MindSpore uses a fine-grained parallel strategy of splitting operators, that is, each operator in the figure is split into a cluster to complete parallel operations. The splitting strategy during this period may be very complicated, but as a developer who advocates Pythonic code, you don't need to care about the underlying implementation, as long as the top-level API computation is efficient.
|
||||
|
||||
## Installation
|
||||
|
||||
### Binaries
|
||||
|
||||
MindSpore offers build options across multiple backends:
|
||||
|
||||
| Hardware Platform | Operating System | Status |
|
||||
| :---------------- | :--------------- | :----- |
|
||||
| Ascend910 | Ubuntu-x86 | ✔️ |
|
||||
| | EulerOS-x86 | ✔️ |
|
||||
| | EulerOS-aarch64 | ✔️ |
|
||||
| GPU CUDA 9.2 | Ubuntu-x86 | ✔️ |
|
||||
| GPU CUDA 10.1 | Ubuntu-x86 | ✔️ |
|
||||
| CPU | Ubuntu-x86 | ✔️ |
|
||||
|
||||
For installation using pip, take `Ubuntu-x86` and `CPU` build version as an example:
|
||||
|
||||
1. Download whl from [MindSpore website](https://www.mindspore.cn/), and install the package.
|
||||
|
||||
```
|
||||
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/cpu/ubuntu-x86/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl
|
||||
```
|
||||
|
||||
2. Run the following command to verify the install.
|
||||
|
||||
```
|
||||
python -c 'import mindspore'
|
||||
```
|
||||
|
||||
### From Source
|
||||
|
||||
[Install MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
## Quickstart
|
||||
|
||||
See the [Quick Start](https://www.mindspore.cn/tutorial/en/0.1.0-alpha/quick_start/quick_start.html)
|
||||
to implement the image classification.
|
||||
|
||||
## Docs
|
||||
|
||||
More details about installation guide, tutorials and APIs, please see the
|
||||
[User Documentation](https://gitee.com/mindspore/docs).
|
||||
|
||||
## Community
|
||||
|
||||
### Governance
|
||||
|
||||
Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/community/blob/master/governance.md).
|
||||
|
||||
### Communication
|
||||
|
||||
- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers.
|
||||
- IRC channel at `#mindspore` (only for meeting minutes logging purpose)
|
||||
- Video Conferencing: meet.jit.si
|
||||
- Mailing-list: https://mailweb.mindspore.cn/postorius/lists
|
||||
|
||||
## Contributing
|
||||
|
||||
Welcome contributions. See our [Contributor Wiki](CONTRIBUTING.md) for
|
||||
more details.
|
||||
|
||||
## Release Notes
|
||||
|
||||
The release notes, see our [RELEASE](RELEASE.md).
|
||||
|
||||
## License
|
||||
|
||||
[Apache License 2.0](LICENSE)
|
|
@ -0,0 +1,73 @@
|
|||
# Release 0.1.0-alpha
|
||||
|
||||
## Main Features
|
||||
|
||||
### Ascend 910 Training and Inference Framework
|
||||
* Recommended OS: Ubuntu 16.04 (or later) or EulerOS 2.0
|
||||
* Python version: 3.7.5
|
||||
* Preset models
|
||||
* ResNet-50: residual structure-based convolutional neural network (CNN) for image classification, which is widely used.
|
||||
* AlexNet: classic CNN for image classification, achieving historical results in ImageNet LSVRC-2012.
|
||||
* LeNet: classic CNN for image classification, which was proposed by Yann LeCun.
|
||||
* VGG16: classic CNN for image classification, which was proposed by Oxford Visual Geometry Group.
|
||||
* YoloV3: real-time object detection network.
|
||||
* NEZHA: BERT-based Chinese pre-training network produced by Huawei Noah's Ark Laboratory.
|
||||
* Execution modes
|
||||
* Graph mode: provides graph optimization methods such as memory overcommitment, IR fusion, and buffer fusion to achieve optimal execution performance.
|
||||
* PyNative mode: single-step execution mode, facilitating process debugging.
|
||||
* Debugging capability and methods
|
||||
* Save CheckPoints and Summary data during training.
|
||||
* Support asynchronous printing.
|
||||
* Dump the computing data.
|
||||
* Support profiling analysis of the execution process performance.
|
||||
* Distributed execution
|
||||
* Support AllReduce, AllGather, and BroadCast collective communication.
|
||||
* AllReduce data parallel: Each device obtains different training data, which accelerates the overall training process.
|
||||
* Collective communication-based layerwise parallel: Models are divided and allocated to different devices to solve the problem of insufficient memory for large model processing and improve the training speed.
|
||||
* Automatic parallel mode: The better data and model parallel mode can be predicted based on the cost model. It is recommended that this mode be used on ResNet series networks.
|
||||
* Automatic differentiation
|
||||
* Implement automatic differentiation based on Source to Source.
|
||||
* Support distributed scenarios and automatic insertion of reverse communication operators.
|
||||
* Data processing, augmentation, and save format
|
||||
* Load common datasets such as ImageNet, MNIST, CIFAR-10, and CIFAR-100.
|
||||
* Support common data loading pipeline operations, such as shuffle, repeat, batch, map, and sampler.
|
||||
* Provide basic operator libraries to cover common CV scenarios.
|
||||
* Support users to customize Python data augmentation operators through the Pyfunc mechanism.
|
||||
* Support the access of user-defined datasets through the GeneratorDataset mechanism.
|
||||
* Provide the MindSpore data format, data aggregation and storage, random access example, data partition, efficient parallel read, user-defined index, and dataset search.
|
||||
* Convert user datasets to the MindSpore data format.
|
||||
* After data processing and augmentation, provide training applications in feed and graph modes.
|
||||
* FP32/16 mixed precision computation, supporting automatic and manual configuration
|
||||
* Provide common operators such as nn, math, and array, which can be customized.
|
||||
|
||||
### Inference Deployment
|
||||
* Deploy models in MindSpore format on the Ascend 310 platform for inference.
|
||||
* Save models in ONNX format.
|
||||
* Support saving models in LITE format and running models based on the lightweight inference framework.
|
||||
* Recommended OS: Android 4.3 or later
|
||||
* Supported network type: LeNet
|
||||
* Provide the generalization operators generated by TVM and operators generated after specific networks are tuned.
|
||||
|
||||
### Other Hardware Support
|
||||
* GPU platform training
|
||||
* Recommended OS: Ubuntu 16.04
|
||||
* CUDA version: 9.2 or 10.1
|
||||
* CuDNN version: 7.6 or later
|
||||
* Python version: 3.7.5
|
||||
* NCCL version: 2.4.8-1
|
||||
* OpenMPI version: 3.1.5
|
||||
* Supported models: AlexNet, LeNet, and LSTM
|
||||
* Supported datasets: MNIST and CIFAR-10
|
||||
* Support data parallel.
|
||||
* CPU platform training
|
||||
* Recommended OS: Ubuntu 16.04
|
||||
* Python version: 3.7.5
|
||||
* Supported model: LeNet
|
||||
* Supported dataset: MNIST
|
||||
* Provide only the stand-alone operation version.
|
||||
|
||||
## Peripherals and Tools
|
||||
* [MindSpore Official Website](https://www.mindspore.cn/)
|
||||
* [MindInsight Visualization Debugging and Optimization](https://gitee.com/mindspore/mindinsight)
|
||||
* [MindArmour Model Security Hardening Package](https://gitee.com/mindspore/mindarmour)
|
||||
* [GraphEngine Computational Graph Engine](https://gitee.com/mindspore/graphengine)
|
|
@ -0,0 +1,14 @@
|
|||
# Security Risk Description
|
||||
|
||||
1. When MindSpore is used for AI model training, if the user-defined computational graph structure (for example, Python code for generating the MindSpore computational graph) is provided by an untrusted third party, malicious code may exist and will be loaded and executed to attack the system.
|
||||
2. Model files are stored in binary mode. When MindSpore is used to optimize or infer AI models and the model files are loaded in deserialization mode, once malicious code is written into the model files, the code is loaded and executed, causing attacks on the system.
|
||||
3. MindSpore performs only model training and inference based on the data provided by users. Users need to protect data security to avoid privacy leakage.
|
||||
4. MindSpore is a distributed training platform. When MindSpore is used for distributed training, if an Ascend chip is used for training, a device provides a secure transmission protocol for gradient fusion. If GPUs or other clusters are used for training, identity authentication and secure transmission are not provided.
|
||||
|
||||
# Security Usage Suggestions
|
||||
|
||||
1. Run MindSpore in the sandbox.
|
||||
2. Run MindSpore as a non-root user.
|
||||
3. Ensure that the source of a computational graph structure is trustworthy. Do not write code irrelevant to model training in the network structure definition.
|
||||
4. Ensure that the source of a network model is trustworthy or enter secure network model parameters to prevent model parameters from being tampered with.
|
||||
5. Ensure that GPU distributed training is performed on an isolated cluster network.
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
git submodule update --init --recursive
|
||||
|
||||
|
|
@ -0,0 +1,468 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
BASEPATH=$(cd "$(dirname $0)"; pwd)
|
||||
PROJECT_PATH="${BASEPATH}"
|
||||
CUDA_PATH=""
|
||||
CUDNN_PATH=""
|
||||
export BUILD_PATH="${BASEPATH}/build/"
|
||||
# print usage message
|
||||
usage()
|
||||
{
|
||||
echo "Usage:"
|
||||
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-s] [-b ge|cpu] [-m infer|train] \\"
|
||||
echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
|
||||
echo " [-P on|off] [-z] [-M on|off] [-V 9.2|10.1] [-I] [-K]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " -d Debug mode"
|
||||
echo " -r Release mode, default mode"
|
||||
echo " -v Display build command"
|
||||
echo " -c Enable code coverage switch, default off"
|
||||
echo " -t Run testcases switch, default on"
|
||||
echo " -g Use glog to output log, default on"
|
||||
echo " -h Print usage"
|
||||
echo " -s Install or setup"
|
||||
echo " -b Select other backend, available: \\"
|
||||
echo " ge:graph engine, cpu"
|
||||
echo " -m Select mode, available: infer, train, default is infer "
|
||||
echo " -a Enable ASAN, default off"
|
||||
echo " -p Enable pipeline profile, default off"
|
||||
echo " -i Enable increment building, default off"
|
||||
echo " -L Enable load ANF-IR as input of 'infer', default off"
|
||||
echo " -R Enable the time_line record, default off"
|
||||
echo " -j[n] Set the threads when building (Default: -j8)"
|
||||
echo " -e Use gpu, d or cpu"
|
||||
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
|
||||
echo " -Q Enable dump end to end, default off"
|
||||
echo " -D Enable dumping of function graph ir, default on"
|
||||
echo " -z Compile dataset & mindrecord, default off"
|
||||
echo " -M Enable MPI and NCCL for GPU training, default off"
|
||||
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
|
||||
echo " -I Compile predict, default off"
|
||||
echo " -K Compile with AKG, default off"
|
||||
}
|
||||
|
||||
# check value of input is 'on' or 'off'
|
||||
# usage: check_on_off arg_value arg_name
|
||||
check_on_off()
|
||||
{
|
||||
if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then
|
||||
echo "Invalid value $1 for option -$2"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# check and set options
|
||||
checkopts()
|
||||
{
|
||||
# Init default values of build options
|
||||
THREAD_NUM=8
|
||||
DEBUG_MODE="off"
|
||||
VERBOSE=""
|
||||
ENABLE_COVERAGE="off"
|
||||
RUN_TESTCASES="off"
|
||||
EXECUTE_SETUP="off"
|
||||
ENABLE_BACKEND=""
|
||||
TRAIN_MODE="INFER"
|
||||
ENABLE_ASAN="off"
|
||||
ENABLE_PROFILE="off"
|
||||
INC_BUILD="off"
|
||||
ENABLE_LOAD_IR="off"
|
||||
ENABLE_TIMELINE="off"
|
||||
ENABLE_DUMP2PROTO="on"
|
||||
ENABLE_DUMPE2E="off"
|
||||
ENABLE_DUMP_IR="on"
|
||||
COMPILE_MINDDATA="off"
|
||||
ENABLE_MPI="off"
|
||||
CUDA_VERSION="9.2"
|
||||
COMPILE_PREDICT="off"
|
||||
USE_GLOG="on"
|
||||
PREDICT_PLATFORM=""
|
||||
ENABLE_AKG="off"
|
||||
|
||||
# Process the options
|
||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt
|
||||
do
|
||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||
case "${opt}" in
|
||||
d)
|
||||
DEBUG_MODE="on"
|
||||
;;
|
||||
r)
|
||||
DEBUG_MODE="off"
|
||||
;;
|
||||
v)
|
||||
VERBOSE="VERBOSE=1"
|
||||
;;
|
||||
j)
|
||||
THREAD_NUM=$OPTARG
|
||||
;;
|
||||
c)
|
||||
check_on_off $OPTARG c
|
||||
ENABLE_COVERAGE="$OPTARG"
|
||||
;;
|
||||
t)
|
||||
check_on_off $OPTARG t
|
||||
RUN_TESTCASES="$OPTARG"
|
||||
;;
|
||||
g)
|
||||
check_on_off $OPTARG g
|
||||
USE_GLOG="$OPTARG"
|
||||
;;
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
s)
|
||||
EXECUTE_SETUP="on"
|
||||
;;
|
||||
b)
|
||||
if [[ "X$OPTARG" != "Xge" && "X$OPTARG" != "Xcpu" ]]; then
|
||||
echo "Invalid value ${OPTARG} for option -b"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
ENABLE_BACKEND=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]')
|
||||
if [[ "X$ENABLE_BACKEND" == "XGE" ]]; then
|
||||
ENABLE_GE="on"
|
||||
fi
|
||||
if [[ "X$ENABLE_BACKEND" != "XCPU" ]]; then
|
||||
ENABLE_CPU="on"
|
||||
fi
|
||||
;;
|
||||
a)
|
||||
check_on_off $OPTARG a
|
||||
ENABLE_ASAN="$OPTARG"
|
||||
;;
|
||||
p)
|
||||
check_on_off $OPTARG p
|
||||
ENABLE_PROFILE="$OPTARG"
|
||||
;;
|
||||
i)
|
||||
INC_BUILD="on"
|
||||
;;
|
||||
m)
|
||||
if [[ "X$OPTARG" != "Xinfer" && "X$OPTARG" != "Xtrain" ]]; then
|
||||
echo "Invalid value ${OPTARG} for option -m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
TRAIN_MODE=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]')
|
||||
;;
|
||||
L)
|
||||
ENABLE_LOAD_IR="on"
|
||||
echo "build with enable load anf ir"
|
||||
;;
|
||||
R)
|
||||
ENABLE_TIMELINE="on"
|
||||
echo "enable time_line record"
|
||||
;;
|
||||
e)
|
||||
if [[ "X$OPTARG" == "Xgpu" ]]; then
|
||||
ENABLE_GPU="on"
|
||||
ENABLE_CPU="on"
|
||||
elif [[ "X$OPTARG" == "Xd" ]]; then
|
||||
ENABLE_D="on"
|
||||
ENABLE_CPU="on"
|
||||
elif [[ "X$OPTARG" == "Xcpu" ]]; then
|
||||
ENABLE_CPU="on"
|
||||
else
|
||||
echo "Invalid value ${OPTARG} for option -e"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
M)
|
||||
check_on_off $OPTARG M
|
||||
ENABLE_MPI="$OPTARG"
|
||||
;;
|
||||
V)
|
||||
if [[ "X$OPTARG" != "X9.2" && "X$OPTARG" != "X10.1" ]]; then
|
||||
echo "Invalid value ${OPTARG} for option -V"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
CUDA_VERSION="$OPTARG"
|
||||
;;
|
||||
P)
|
||||
check_on_off $OPTARG p
|
||||
ENABLE_DUMP2PROTO="$OPTARG"
|
||||
echo "enable dump anf graph to proto file"
|
||||
;;
|
||||
Q)
|
||||
check_on_off $OPTARG Q
|
||||
ENABLE_DUMPE2E="$OPTARG"
|
||||
echo "enable dump end to end"
|
||||
;;
|
||||
D)
|
||||
check_on_off $OPTARG D
|
||||
ENABLE_DUMP_IR="$OPTARG"
|
||||
echo "enable dump function graph ir"
|
||||
;;
|
||||
z)
|
||||
COMPILE_MINDDATA="on"
|
||||
;;
|
||||
I)
|
||||
COMPILE_PREDICT="on"
|
||||
if [[ "$OPTARG" == "arm64" ]]; then
|
||||
PREDICT_PLATFORM="arm64"
|
||||
elif [[ "$OPTARG" == "x86_64" ]]; then
|
||||
PREDICT_PLATFORM="x86_64"
|
||||
else
|
||||
echo "-I parameter must be arm64 or x86_64"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
K)
|
||||
ENABLE_AKG="on"
|
||||
echo "enable compile with akg"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
usage
|
||||
exit 1
|
||||
esac
|
||||
done
|
||||
}
|
||||
checkopts "$@"
|
||||
echo "---------------- mindspore: build start ----------------"
|
||||
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
|
||||
git submodule update --init graphengine
|
||||
|
||||
build_exit()
|
||||
{
|
||||
echo "$@" >&2
|
||||
stty echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create building path
|
||||
build_mindspore()
|
||||
{
|
||||
echo "start build mindspore project."
|
||||
mkdir -pv "${BUILD_PATH}/mindspore"
|
||||
cd "${BUILD_PATH}/mindspore"
|
||||
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH"
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_LOAD_ANF_IR=$ENABLE_LOAD_IR"
|
||||
if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON"
|
||||
fi
|
||||
if [[ "X$RUN_TESTCASES" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TESTCASES=ON"
|
||||
fi
|
||||
if [[ -n "$ENABLE_BACKEND" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${ENABLE_BACKEND}=ON"
|
||||
fi
|
||||
if [[ -n "$TRAIN_MODE" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${TRAIN_MODE}=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_ASAN" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_PROFILE" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PROFILE=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_TIMELINE" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TIMELINE=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_DUMP2PROTO" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_PROTO=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_E2E=ON"
|
||||
fi
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_IR=${ENABLE_DUMP_IR^^}"
|
||||
if [[ "X$ENABLE_MPI" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MPI=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_D" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_GPU" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DCUDA_PATH=$CUDA_PATH -DCUDNN_PATH=$CUDNN_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}"
|
||||
fi
|
||||
if [[ "X$ENABLE_CPU" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON"
|
||||
fi
|
||||
if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON"
|
||||
fi
|
||||
if [[ "X$USE_GLOG" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_AKG" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
||||
fi
|
||||
echo "${CMAKE_ARGS}"
|
||||
if [[ "X$INC_BUILD" = "Xoff" ]]; then
|
||||
cmake ${CMAKE_ARGS} ../..
|
||||
fi
|
||||
make ${VERBOSE} -j$THREAD_NUM
|
||||
if [[ "X$EXECUTE_SETUP" = "Xon" ]]; then
|
||||
make install
|
||||
fi
|
||||
echo "success to build mindspore project!"
|
||||
}
|
||||
|
||||
build_predict()
|
||||
{
|
||||
git submodule update --init --recursive third_party/incubator-tvm
|
||||
echo "start build predict project"
|
||||
|
||||
git submodule update --init --recursive third_party/flatbuffers
|
||||
git submodule update --init --recursive third_party/googletest
|
||||
git submodule update --init --recursive third_party/protobuf
|
||||
|
||||
rm -rf "${BASEPATH}/predict/build"
|
||||
mkdir -pv "${BASEPATH}/predict/build"
|
||||
rm -rf "${BASEPATH}/predict/output"
|
||||
mkdir -pv "${BASEPATH}/predict/output"
|
||||
|
||||
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
|
||||
if [ "${ANDROID_NDK}" ]; then
|
||||
echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
|
||||
else
|
||||
echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/ \e[0m"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
#build flatbuf
|
||||
cd "${BASEPATH}/third_party/flatbuffers"
|
||||
rm -rf build && mkdir -p build && cd build && cmake .. && make -j$THREAD_NUM
|
||||
FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc
|
||||
cd "${BASEPATH}"/predict/schema && mkdir -p "${BASEPATH}"/predict/schema/inner
|
||||
find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b
|
||||
find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o ${BASEPATH}/predict/schema/inner
|
||||
|
||||
# check LLVM_PATH
|
||||
if [ "${LLVM_PATH}" == "" ]; then
|
||||
echo "Please set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config"
|
||||
exit
|
||||
fi
|
||||
|
||||
#build tvm
|
||||
tvm_open_source="${BASEPATH}/third_party/incubator-tvm"
|
||||
tvm_kernel_build="${BASEPATH}/predict/module/tvm_kernel"
|
||||
if [ ! -f "${tvm_kernel_build}"/incubator-tvm/build/libtvm.so ]; then
|
||||
rm -fr "${tvm_kernel_build}"/incubator-tvm
|
||||
cp -fr "${tvm_open_source}" "${tvm_kernel_build}"
|
||||
mkdir -p "${tvm_kernel_build}"/incubator-tvm/build
|
||||
patch -d "${tvm_kernel_build}"/incubator-tvm -p1 < "${BASEPATH}"/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch
|
||||
cp "${tvm_kernel_build}"/lite/src/codegen/llvm/lite_rtfunc_reset.cc "${tvm_kernel_build}"/incubator-tvm/src/codegen/llvm/
|
||||
cp "${tvm_open_source}"/cmake/config.cmake "${tvm_kernel_build}"/incubator-tvm
|
||||
if [ "${LLVM_PATH}" ]; then
|
||||
sed -i "s#set(USE_LLVM .*)#set(USE_LLVM \"${LLVM_PATH}\")#g" "${tvm_kernel_build}"/incubator-tvm/config.cmake
|
||||
else
|
||||
echo "need set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config"
|
||||
fi
|
||||
cd "${tvm_kernel_build}"/incubator-tvm/build
|
||||
cmake ..
|
||||
make -j$THREAD_NUM
|
||||
else
|
||||
cd "${tvm_kernel_build}"/incubator-tvm/build
|
||||
make -j$THREAD_NUM
|
||||
fi
|
||||
|
||||
#gen op
|
||||
predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_x86"
|
||||
predict_platform="x86"
|
||||
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
|
||||
predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_arm64"
|
||||
predict_platform="arm64"
|
||||
fi
|
||||
|
||||
need_get_libs=true
|
||||
if [ -d "${predict_tvm_op_lib_path}" ]; then
|
||||
file_list=$(ls "${predict_tvm_op_lib_path}")
|
||||
if [ -n "${file_list}" ]; then
|
||||
libstime=$(stat -c %Y "${predict_tvm_op_lib_path}"/* | sort -u | tail -n1)
|
||||
pythontime=$(find "${BASEPATH}"/predict/module/tvm_kernel/lite/python/ -name "*.py" -exec stat -c %Y {} \; |
|
||||
sort -u | tail -n1)
|
||||
if [ "${libstime}" -ge "${pythontime}" ]; then
|
||||
need_get_libs=false
|
||||
else
|
||||
rm -fr "${predict_tvm_op_lib_path}"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if $need_get_libs; then
|
||||
PYTHONPATH_OLD=${PYTHONPATH}
|
||||
export PYTHONPATH="${tvm_kernel_build}/incubator-tvm/python:${tvm_kernel_build}/incubator-tvm/topi/python:${tvm_kernel_build}/incubator-tvm/nnvm/python:${tvm_kernel_build}/lite/python:"
|
||||
cd "${BASEPATH}"/predict/module/tvm_kernel/lite/python/at_ops
|
||||
python3 at_gen_strip.py ${predict_platform}
|
||||
export PYTHONPATH=${PYTHONPATH_OLD}
|
||||
fi
|
||||
|
||||
cd "${BASEPATH}/predict/build"
|
||||
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
|
||||
-DANDROID_NATIVE_API_LEVEL=android-19 -DANDROID_NDK="${ANDROID_NDK}" \
|
||||
-DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" -DANDROID_STL="c++_shared" \
|
||||
-DANDROID_ABI="arm64-v8a" -DENABLE_PREDICT_ARM64=ON -DANDROID_ALLOW_UNDEFINED_SYMBOLS=TRUE ..
|
||||
elif [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
|
||||
cmake ..
|
||||
fi
|
||||
|
||||
make ${VERBOSE} -j$THREAD_NUM
|
||||
if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
|
||||
cd "${BASEPATH}/predict/build/test" && ./run_tests.sh
|
||||
fi
|
||||
|
||||
# copy securec include files
|
||||
mkdir -p "${BASEPATH}/predict/output/include/securec/include"
|
||||
cp "${BASEPATH}"/third_party/securec/include/* "${BASEPATH}"/predict/output/include/securec/include
|
||||
|
||||
cd "${BASEPATH}/predict/output/"
|
||||
if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
|
||||
tar -cf MSPredict-0.1.0-linux_x86_64.tar.gz include/ lib/ --warning=no-file-changed
|
||||
elif [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
|
||||
tar -cf MSPredict-0.1.0-linux_aarch64.tar.gz include/ lib/ --warning=no-file-changed
|
||||
fi
|
||||
echo "success to build predict project!"
|
||||
}
|
||||
|
||||
if [[ "X$COMPILE_PREDICT" = "Xon" ]]; then
|
||||
build_predict
|
||||
echo "---------------- mindspore: build end ----------------"
|
||||
exit
|
||||
else
|
||||
build_mindspore
|
||||
fi
|
||||
|
||||
if [[ "X$INC_BUILD" = "Xoff" ]]; then
|
||||
if [[ "X$ENABLE_GE" = "Xon" ]]; then
|
||||
bash "${PROJECT_PATH}/package.sh" ge
|
||||
elif [[ "X$ENABLE_GPU" = "Xon" ]]; then
|
||||
bash "${PROJECT_PATH}/package.sh" ms gpu
|
||||
elif [[ "X$ENABLE_D" = "Xon" ]] || [[ "X$ENABLE_CPU" = "Xon" ]]; then
|
||||
bash "${PROJECT_PATH}/package.sh" ms
|
||||
else
|
||||
bash "${PROJECT_PATH}/package.sh" debug
|
||||
fi
|
||||
fi
|
||||
|
||||
cp -rf ${BUILD_PATH}/package/mindspore/lib ${BUILD_PATH}/../mindspore
|
||||
cp -rf ${BUILD_PATH}/package/mindspore/*.so ${BUILD_PATH}/../mindspore
|
||||
|
||||
if [[ -d "${BUILD_PATH}/package/build" ]]; then
|
||||
rm -rf "${BUILD_PATH}/package/build"
|
||||
fi
|
||||
echo "---------------- mindspore: build end ----------------"
|
|
@ -0,0 +1,70 @@
|
|||
message(STATUS "ENABLE_GE set to FALSE, compiling GraphEngine")
|
||||
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine)
|
||||
|
||||
message(STATUS "ge dir: ${GE_SOURCE_DIR}")
|
||||
# download json headers, rather than whole repository
|
||||
include(${GE_SOURCE_DIR}/cmake/ge_utils.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/json.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
|
||||
|
||||
# for CPU/GPU mode, find c_sec and slog from local prebuild
|
||||
if (NOT ENABLE_D)
|
||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||
find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
|
||||
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
||||
elseif (DEFINED ENV{D_LINK_PATH})
|
||||
set(GE_LIB_PATH $ENV{D_LINK_PATH})
|
||||
set(GE_SYS_ARCH "")
|
||||
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
# x86 ubuntu
|
||||
set(GE_SYS_ARCH "x86_64")
|
||||
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||
# arm euleros
|
||||
set(GE_SYS_ARCH "aarch64")
|
||||
else()
|
||||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
|
||||
endif()
|
||||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
|
||||
find_library(c_sec libc_sec.so ${GE_LIB_PATH})
|
||||
find_library(slog libslog.so ${GE_LIB_PATH})
|
||||
find_library(mmpa libmmpa.so ${GE_LIB_PATH})
|
||||
find_library(runtime libruntime.so ${GE_LIB_PATH})
|
||||
find_library(msprof libmsprof.so ${GE_LIB_PATH})
|
||||
find_library(register libregister.so ${GE_LIB_PATH})
|
||||
find_library(hccl libhccl.so ${GE_LIB_PATH})
|
||||
find_library(cce libcce.so ${GE_LIB_PATH})
|
||||
find_library(resource libresource.so ${GE_LIB_PATH})
|
||||
else()
|
||||
set(HIAI_INSTALLED_DIR /usr/local/HiAI)
|
||||
set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64)
|
||||
set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/runtime/lib64)
|
||||
find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR})
|
||||
find_library(slog libslog.so ${HIAI_DRIVER_DIR})
|
||||
find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR})
|
||||
|
||||
find_library(cce libcce.so ${HIAI_RUNTIME_DIR})
|
||||
find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR})
|
||||
find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR})
|
||||
find_library(msprof libmsprof.so ${HIAI_RUNTIME_DIR})
|
||||
find_library(register libregister.so ${HIAI_RUNTIME_DIR})
|
||||
find_library(resource libresource.so ${HIAI_RUNTIME_DIR})
|
||||
endif()
|
||||
|
||||
# compile libraries from following directories
|
||||
# this cmake file is called only when NOT ENABLE_GE is set
|
||||
set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
|
||||
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
# force __FILE__ to show relative path of file, from source directory
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined")
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/common/graph)
|
||||
if(ENABLE_D)
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/common)
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS})
|
|
@ -0,0 +1,36 @@
|
|||
# googletest library
#
# Builds the vendored googletest as a shared, PIC library and exposes:
#   GTEST_LIBRARY

if (NOT TARGET gtest)
    set(BUILD_TESTING OFF CACHE BOOL "Disable glog test")

    # Save the global build settings that gtest's CMakeLists would otherwise
    # inherit; restore them after add_subdirectory().
    set(_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE})
    set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
    set(_ms_tmp_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
    set(_ms_tmp_CMAKE_MACOSX_RPATH ${CMAKE_MACOSX_RPATH})

    set(CMAKE_POSITION_INDEPENDENT_CODE ON)
    set(BUILD_SHARED_LIBS ON)
    set(CMAKE_MACOSX_RPATH TRUE)
    set(CMAKE_CXX_FLAGS "${SECURE_CXX_FLAGS}")

    if (CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "5.0" AND CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64" AND SYSTEM_TYPE MATCHES "euleros")
        # -D_GLIBCXX_USE_CXX11_ABI=0 added for the ABI incompatible for libtsdclient.so
        # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
    endif()

    add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest ${CMAKE_BINARY_DIR}/googletest)

    set(CMAKE_POSITION_INDEPENDENT_CODE ${_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE})
    set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
    set(BUILD_SHARED_LIBS ${_ms_tmp_BUILD_SHARED_LIBS})
    set(CMAKE_MACOSX_RPATH ${_ms_tmp_CMAKE_MACOSX_RPATH})
endif()

include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/include)

set(GTEST_LIBRARY gtest)
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
# Protobuf
#
# Builds the vendored protobuf and exposes:
#   PROTOBUF_LIBRARY  - link target for libprotobuf
#   PROTOC_EXECUTABLE - generator expression for the protoc binary
if (NOT TARGET protobuf::libprotobuf)
    set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disable protobuf test")
    set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "Gen shared library")
    set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})

    # protobuf sources do not build under -Wall/-Werror; drop them for the
    # subdirectory only and restore afterwards.
    string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
    string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

    add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/protobuf/cmake ${CMAKE_BINARY_DIR}/protobuf)

    set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
endif ()

include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/protobuf/src)

set(PROTOBUF_LIBRARY protobuf::libprotobuf)
set(PROTOC_EXECUTABLE $<TARGET_FILE:protobuf::protoc>)

# Generate C++ sources (<name>.pb.cc / <name>.pb.h) from the given .proto
# files, mirroring the source tree layout under ${CMAKE_BINARY_DIR}.
function(ms_protobuf_generate c_var h_var)
    if(NOT ARGN)
        message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
        return()
    endif()

    set(${c_var})
    set(${h_var})

    foreach(file ${ARGN})
        get_filename_component(abs_file ${file} ABSOLUTE)
        get_filename_component(file_name ${file} NAME_WE)
        get_filename_component(file_dir ${abs_file} PATH)
        file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})

        list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
        list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")

        add_custom_command(
                OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
                COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
                COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                DEPENDS protobuf::protoc ${abs_file}
                COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
    endforeach()

    set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
    set(${c_var} ${${c_var}} PARENT_SCOPE)
    set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()

# Generate C++ and Python bindings from the given .proto files. The Python
# module is rewritten to use relative imports and copied to mindspore/train/.
function(ms_protobuf_generate_py c_var h_var py_var)
    if(NOT ARGN)
        # was "ms_protobuf_generate()": report the function actually called
        message(SEND_ERROR "Error: ms_protobuf_generate_py() called without any proto files")
        return()
    endif()

    set(${c_var})
    set(${h_var})
    set(${py_var})

    foreach(file ${ARGN})
        get_filename_component(abs_file ${file} ABSOLUTE)
        get_filename_component(file_name ${file} NAME_WE)
        get_filename_component(file_dir ${abs_file} PATH)
        file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})

        list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
        list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
        list(APPEND ${py_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py")

        add_custom_command(
                OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
                COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
                COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                # --python_out was issued twice; once is sufficient
                COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                COMMAND perl -pi -e "s/import (.+_pb2.*)/from . import \\1/" "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
                COMMAND cp "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py" "${PROJECT_SOURCE_DIR}/mindspore/train/"
                DEPENDS protobuf::protoc ${abs_file}
                COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
    endforeach()

    set_source_files_properties(${${c_var}} ${${h_var}} ${${py_var}} PROPERTIES GENERATED TRUE)
    set(${c_var} ${${c_var}} PARENT_SCOPE)
    set(${h_var} ${${h_var}} PARENT_SCOPE)
    set(${py_var} ${${py_var}} PARENT_SCOPE)
endfunction()
|
|
@ -0,0 +1,19 @@
|
|||
# securec library
#
# Builds the vendored securec and exposes:
#   SECUREC_LIBRARY

if (NOT TARGET securec)
    # NOTE(review): CMAKE_POSITION_INDEPENDENT_CODE is saved/restored here but
    # never modified in between — presumably left over from an older version.
    set(_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE})
    set(_ms_tmp_CMAKE_C_FLAGS ${CMAKE_C_FLAGS})

    # Build securec with the hardened flag set only.
    set(CMAKE_C_FLAGS "${SECURE_CXX_FLAGS}")
    add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec ${CMAKE_BINARY_DIR}/securec)
    set(CMAKE_POSITION_INDEPENDENT_CODE ${_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE})
    set(CMAKE_C_FLAGS ${_ms_tmp_CMAKE_C_FLAGS})
endif()

include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec/include)

set(SECUREC_LIBRARY securec)
|
|
@ -0,0 +1,25 @@
|
|||
# MS Utils
#

# Query the active Python interpreter for its include directory and shared
# library path, returning them through the two named output variables.
function(find_python_package out_inc out_lib)
    # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python3
    if ("${PYTHON_EXECUTABLE}" STREQUAL "")
        set(PYTHON_EXECUTABLE "python3")
    else()
        set(PYTHON_EXECUTABLE "${PYTHON_EXECUTABLE}")
    endif()

    # NOTE(review): distutils is deprecated since Python 3.10 and removed in
    # 3.12 — these queries will need sysconfig equivalents eventually.
    execute_process(
            COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())"
            RESULT_VARIABLE result
            OUTPUT_VARIABLE inc)
    string(STRIP "${inc}" inc)
    set(${out_inc} ${inc} PARENT_SCOPE)

    execute_process(
            COMMAND "${PYTHON_EXECUTABLE}" -c "import distutils.sysconfig as sysconfig; import os; print(os.path.join(sysconfig.get_config_var('LIBDIR'), sysconfig.get_config_var('LDLIBRARY')))"
            RESULT_VARIABLE result
            OUTPUT_VARIABLE lib)
    string(STRIP "${lib}" lib)
    set(${out_lib} ${lib} PARENT_SCOPE)
endfunction()
|
|
@ -0,0 +1,7 @@
|
|||
# dlpack: header-only tensor interchange headers (pinned commit).
mindspore_add_pkg(dlpack
        VER 0.2
        HEAD_ONLY ./
        URL https://github.com/dmlc/dlpack/archive/0acb731e0e43d15deee27b66f10e4c5b4e667913.zip
        MD5 6b8093f17ad4e830d3c63eb3171c4b45)
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# dmlc-core: header-only common utilities (pinned commit).
mindspore_add_pkg(dmlc_core
        VER 0.3
        HEAD_ONLY ./
        URL https://github.com/dmlc/dmlc-core/archive/808f485387f9a03f78fa9f1159f387d0d91b7a28.zip
        MD5 ea36f94c57752bf40fb02dfc362f1ed9)
|
||||
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
# Eigen 3: header-only linear algebra, exported as mindspore::eigen.
set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(Eigen3
        VER 3.3.7
        URL https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.gz
        MD5 9e30f67e8531477de4117506fe44669b
        CMAKE_OPTION -DBUILD_TESTING=OFF)
find_package(Eigen3 3.3.7 REQUIRED ${MS_FIND_NO_DEFAULT_PATH})
include_directories(${Eigen3_INC})
include_directories(${EIGEN3_INCLUDE_DIR})
set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE)
add_library(mindspore::eigen ALIAS Eigen3::Eigen)
|
|
@ -0,0 +1,53 @@
|
|||
# FlatBuffers: library, flatc compiler, and a schema-compilation helper.
set(flatbuffers_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(flatbuffers_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(flatbuffers
        VER 1.11.0
        LIBS flatbuffers
        EXE flatc
        URL https://github.com/google/flatbuffers/archive/v1.11.0.tar.gz
        MD5 02c64880acb89dbd57eebacfd67200d8
        CMAKE_OPTION -DFLATBUFFERS_BUILD_TESTS=OFF )

include_directories(${flatbuffers_INC})
add_library(mindspore::flatbuffers ALIAS flatbuffers::flatbuffers)
add_executable(mindspore::flatc ALIAS flatbuffers::flatc)

# Run flatc over each schema, producing <schema>_generated.h in
# generated_output_dir, and gather all outputs under one custom target.
function(ms_build_flatbuffers source_schema_files
        source_schema_dirs
        custom_target_name
        generated_output_dir)

    set(total_schema_dirs "")
    set(total_generated_files "")
    set(FLATC mindspore::flatc)
    foreach (schema_dir ${source_schema_dirs})
        set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
    endforeach()

    foreach(schema ${source_schema_files})
        get_filename_component(filename ${schema} NAME_WE)
        if (NOT ${generated_output_dir} STREQUAL "")
            # was "$(unknown)_generated.h" (garbled): flatc names its output
            # after the schema's base name.
            set(generated_file ${generated_output_dir}/${filename}_generated.h)
            add_custom_command(
                    OUTPUT ${generated_file}
                    COMMAND ${FLATC} --gen-mutable
                    --reflect-names --gen-object-api -o ${generated_output_dir}
                    ${total_schema_dirs}
                    -c -b --reflect-types ${schema}
                    DEPENDS ${FLATC} ${schema}
                    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
                    COMMENT "Running C++ flatbuffers compiler on ${schema}" VERBATIM)
            list(APPEND total_generated_files ${generated_file})
        endif()
    endforeach()

    add_custom_target(${custom_target_name} ALL
            DEPENDS ${total_generated_files})

    if (NOT ${generated_output_dir} STREQUAL "")
        include_directories(${generated_output_dir})
        set_property(TARGET ${custom_target_name}
                PROPERTY GENERATED_OUTPUT_DIR
                ${generated_output_dir})
    endif()
endfunction()
|
|
@ -0,0 +1,10 @@
|
|||
# glog: logging library, exported as mindspore::glog.
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS}")
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(glog
        VER 0.4.0
        LIBS glog
        URL https://github.com/google/glog/archive/v0.4.0.tar.gz
        MD5 0daea8785e6df922d7887755c3d100d0
        CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON)
include_directories(${glog_INC})
add_library(mindspore::glog ALIAS glog::glog)
|
|
@ -0,0 +1,13 @@
|
|||
# googletest: test framework, exported as mindspore::gtest; the shared libs
# are copied to the path layout the test build expects.
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(gtest
        VER 1.8.0
        LIBS gtest
        URL https://github.com/google/googletest/archive/release-1.8.0.tar.gz
        MD5 16877098823401d1bf2ed7891d7dce36
        CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON
        -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON)
include_directories(${gtest_INC})
add_library(mindspore::gtest ALIAS gtest::gtest)
file(COPY ${gtest_LIBPATH}/libgtest.so DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest)
file(COPY ${gtest_LIBPATH}/libgtest_main.so DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest)
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
# libjpeg-turbo: static JPEG codec, exported as mindspore::jpeg_turbo.
set(jpeg_turbo_USE_STATIC_LIBS ON)
set(jpeg_turbo_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(jpeg_turbo_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(jpeg_turbo
        VER 2.0.4
        LIBS jpeg
        URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz
        MD5 44c43e4a9fb352f47090804529317c88
        CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DCMAKE_SKIP_RPATH=TRUE
        )
include_directories(${jpeg_turbo_INC})
add_library(mindspore::jpeg_turbo ALIAS jpeg_turbo::jpeg)
|
|
@ -0,0 +1,9 @@
|
|||
# nlohmann/json: header-only JSON library, exported as mindspore::json.
set(nlohmann_json_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(nlohmann_json_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(nlohmann_json
        VER 3.6.1
        HEAD_ONLY ./
        URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip
        MD5 0dc903888211db3a0f170304cd9f3a89)
include_directories(${nlohmann_json_INC})
add_library(mindspore::json ALIAS nlohmann_json)
|
|
@ -0,0 +1,16 @@
|
|||
|
||||
# libtiff: static TIFF codec; ${tiff_INC}/${tiff_LIB} are consumed by the
# OpenCV build below.
set(tiff_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -Wno-unused-result \
    -Wno-unused-but-set-variable -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(tiff_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -Wno-unused-result \
    -Wno-unused-but-set-variable -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(tiff_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")

mindspore_add_pkg(tiff
        VER 4.1.0
        LIBS tiff
        URL https://gitlab.com/libtiff/libtiff/-/archive/v4.1.0/libtiff-v4.1.0.tar.gz
        MD5 21de8d35c1b21ac82663fa9f56d3350d
        CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -Djbig=OFF -Dlzma=OFF -Djpeg12=OFF -Dzstd=OFF -Dpixarlog=OFF
        -Dold-jpeg=OFF -Dwebp=OFF -DBUILD_SHARED_LIBS=OFF)
message("tiff include = ${tiff_INC}")
message("tiff lib = ${tiff_LIB}")
|
|
@ -0,0 +1,11 @@
|
|||
# MKL-DNN (DNNL): CPU primitives, exported as mindspore::dnnl/mkldnn.
set(mkl_dnn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(mkl_dnn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(mkl_dnn
        VER 1.1.1
        LIBS dnnl mkldnn
        URL https://github.com/intel/mkl-dnn/archive/v1.1.1.tar.gz
        MD5 d6a422b00459600bdc22242590953f38
        CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_CPU_RUNTIME='SEQ' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF)
include_directories(${mkl_dnn_INC})
add_library(mindspore::dnnl ALIAS mkl_dnn::dnnl)
add_library(mindspore::mkldnn ALIAS mkl_dnn::mkldnn)
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
# NCCL: GPU collective communication, exported as mindspore::nccl.
set(nccl_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(nccl
        VER 2.4.8-1
        LIBS nccl
        URL https://github.com/NVIDIA/nccl/archive/v2.4.8-1.tar.gz
        MD5 f14b37d6af1c79db5f57cb029a753727
        BUILD_OPTION src.build NVCC_GENCODE="-gencode=arch=compute_70,code=sm_70"
        INSTALL_INCS build/include/*
        INSTALL_LIBS build/lib/*)
include_directories(${nccl_INC})
add_library(mindspore::nccl ALIAS nccl::nccl)
|
|
@ -0,0 +1,11 @@
|
|||
|
||||
# Open MPI: built via autotools, exported as mindspore::ompi.
set(ompi_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(ompi
        VER 3.1.5
        LIBS mpi
        URL https://github.com/open-mpi/ompi/archive/v3.1.5.tar.gz
        MD5 f7f220b26532c11a2efbc0bb73af3282
        PRE_CONFIGURE_COMMAND ./autogen.pl
        CONFIGURE_COMMAND ./configure)
include_directories(${ompi_INC})
add_library(mindspore::ompi ALIAS ompi::mpi)
|
|
@ -0,0 +1,5 @@
|
|||
# ONNX: header-only model format sources.
mindspore_add_pkg(ms_onnx
        VER 1.6.0
        HEAD_ONLY ./
        URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz
        MD5 512f2779d6215d4a36f366b6b9acdf1e)
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
# OpenCV: core/imgcodecs/imgproc only, linked against the libtiff built above
# (tiff_INC/tiff_LIB come from libtiff.cmake).
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")

mindspore_add_pkg(opencv
        VER 4.2.0
        LIBS opencv_core opencv_imgcodecs opencv_imgproc
        URL https://github.com/opencv/opencv/archive/4.2.0.tar.gz
        MD5 e8cb208ce2723481408b604b480183b6
        CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DWITH_PROTOBUF=OFF -DWITH_WEBP=OFF -DWITH_IPP=OFF -DWITH_ADE=OFF
        -DBUILD_ZLIB=ON
        -DBUILD_JPEG=ON
        -DBUILD_PNG=ON
        -DBUILD_OPENEXR=ON
        -DBUILD_TESTS=OFF
        -DBUILD_PERF_TESTS=OFF
        -DBUILD_opencv_apps=OFF
        -DCMAKE_SKIP_RPATH=TRUE
        -DBUILD_opencv_python3=OFF
        -DWITH_FFMPEG=OFF
        -DWITH_TIFF=ON
        -DBUILD_TIFF=OFF
        -DWITH_JASPER=OFF
        -DBUILD_JASPER=OFF
        -DTIFF_INCLUDE_DIR=${tiff_INC}
        -DTIFF_LIBRARY=${tiff_LIB})
include_directories(${opencv_INC}/opencv4)
add_library(mindspore::opencv_core ALIAS opencv::opencv_core)
add_library(mindspore::opencv_imgcodecs ALIAS opencv::opencv_imgcodecs)
add_library(mindspore::opencv_imgproc ALIAS opencv::opencv_imgproc)
|
|
@ -0,0 +1,96 @@
|
|||
# Protobuf fetched via mindspore_add_pkg, built in-tree, and exposed as
# mindspore::protobuf plus the ms_protobuf_generate* codegen helpers.
mindspore_add_pkg(protobuf
        VER 3.8.0
        HEAD_ONLY ./
        URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
        MD5 3d9e32700639618a4d2d342c99d4507a)

set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disable protobuf test")
set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "Gen shared library")
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})

# protobuf sources do not build under -Wall/-Werror; drop them for the
# subdirectory only and restore afterwards.
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
add_subdirectory(${protobuf_DIRPATH}/cmake ${protobuf_DIRPATH}/build)

set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})

set(PROTOBUF_LIBRARY protobuf::libprotobuf)
include_directories(${protobuf_DIRPATH}/src)
add_library(mindspore::protobuf ALIAS libprotobuf)

# Generate C++ sources (<name>.pb.cc / <name>.pb.h) from the given .proto
# files, mirroring the source tree layout under ${CMAKE_BINARY_DIR}.
function(ms_protobuf_generate c_var h_var)
    if(NOT ARGN)
        message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
        return()
    endif()

    set(${c_var})
    set(${h_var})

    foreach(file ${ARGN})
        get_filename_component(abs_file ${file} ABSOLUTE)
        get_filename_component(file_name ${file} NAME_WE)
        get_filename_component(file_dir ${abs_file} PATH)
        file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})

        list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
        list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")

        add_custom_command(
                OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
                COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
                COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                DEPENDS protobuf::protoc ${abs_file}
                COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
    endforeach()

    set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
    set(${c_var} ${${c_var}} PARENT_SCOPE)
    set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()

# Generate C++ and Python bindings from the given .proto files. The Python
# module is rewritten to use relative imports and copied to mindspore/train/.
function(ms_protobuf_generate_py c_var h_var py_var)
    if(NOT ARGN)
        # was "ms_protobuf_generate()": report the function actually called
        message(SEND_ERROR "Error: ms_protobuf_generate_py() called without any proto files")
        return()
    endif()

    set(${c_var})
    set(${h_var})
    set(${py_var})

    foreach(file ${ARGN})
        get_filename_component(abs_file ${file} ABSOLUTE)
        get_filename_component(file_name ${file} NAME_WE)
        get_filename_component(file_dir ${abs_file} PATH)
        file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})

        list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
        list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
        list(APPEND ${py_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py")

        add_custom_command(
                OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
                       "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
                COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
                COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                # --python_out was issued twice; once is sufficient
                COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
                COMMAND perl -pi -e "s/import (.+_pb2.*)/from . import \\1/" "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
                COMMAND cp "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py" "${PROJECT_SOURCE_DIR}/mindspore/train/"
                DEPENDS protobuf::protoc ${abs_file}
                COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
    endforeach()

    set_source_files_properties(${${c_var}} ${${h_var}} ${${py_var}} PROPERTIES GENERATED TRUE)
    set(${c_var} ${${c_var}} PARENT_SCOPE)
    set(${h_var} ${${h_var}} PARENT_SCOPE)
    set(${py_var} ${${py_var}} PARENT_SCOPE)
endfunction()
|
|
@ -0,0 +1,12 @@
|
|||
# pybind11: Python binding headers, exported as mindspore::pybind11_module.
set(pybind11_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(pybind11_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(pybind11
        VER 2.4.3
        URL https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz
        MD5 62254c40f89925bb894be421fe4cdef2
        CMAKE_OPTION -DPYBIND11_TEST=OFF -DPYBIND11_LTO_CXX_FLAGS=FALSE
        )
include_directories(${pybind11_INC})
find_package(pybind11 REQUIRED)
set_property(TARGET pybind11::module PROPERTY IMPORTED_GLOBAL TRUE)
add_library(mindspore::pybind11_module ALIAS pybind11::module)
|
|
@ -0,0 +1,7 @@
|
|||
# rang: header-only terminal color library (pinned commit).
mindspore_add_pkg(rang
        VER 3.1.0
        HEAD_ONLY ./
        URL https://github.com/agauniyal/rang/archive/cabe04d6d6b05356fa8f9741704924788f0dd762.zip
        MD5 0c5c9b251fea9ee7ce32f188655be0ea)
|
||||
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
# SQLite: static, patched build, exported as mindspore::sqlite.
set(sqlite_USE_STATIC_LIBS ON)
set(sqlite_CXXFLAGS)
set(sqlite_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(sqlite_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")

mindspore_add_pkg(sqlite
        VER 3.31.1
        LIBS sqlite3
        URL https://github.com/sqlite/sqlite/archive/version-3.31.1.tar.gz
        MD5 5f4e7b4016c15f4fb5855615279819da
        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001
        CONFIGURE_COMMAND ./configure --enable-shared=no --disable-tcl --disable-editline --enable-json1)
include_directories(${sqlite_INC})
add_library(mindspore::sqlite ALIAS sqlite::sqlite3)
|
|
@ -0,0 +1,8 @@
|
|||
# incubator-tvm (GPU): sources only, consumed header-only.
set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(incubator_tvm_gpu
        VER 0.6.0
        HEAD_ONLY ./
        URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz
        MD5 9cbbd32545a776023acabbba270449fe)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
# incubator-tvm (predict): patched release tarball, exported as
# mindspore::incubator_tvm_predict.
set(incubator_tvm_predict_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_predict_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(incubator_tvm_predict
        VER 0.6.0
        HEAD_ONLY ./
        URL https://github.com/apache/incubator-tvm/release/download/v0.6.0/apache-tvm-src-v0.6.0-incubating.tar.gz
        MD5 2d77a005f0046d937b99c67de82f6438
        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch)
include_directories(${incubator_tvm_predict_INC})
add_library(mindspore::incubator_tvm_predict ALIAS incubator_tvm_predict)
|
|
@ -0,0 +1,59 @@
|
|||
set(SECURE_CXX_FLAGS "")
|
||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
set(SECURE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
|
||||
endif()
|
||||
set(_ms_tmp_CMAKE_CXX_FLAGS_F ${CMAKE_CXX_FLAGS})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
|
||||
include(cmake/utils.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pybind11.cmake)
|
||||
MESSAGE("go to link flatbuffers")
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake)
|
||||
if(USE_GLOG)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake)
|
||||
endif()
|
||||
|
||||
find_package(Python3)
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party)
|
||||
if (ENABLE_CPU)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
|
||||
endif()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dlpack.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dmlc_core.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/rang.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tvm_gpu.cmake)
|
||||
endif()
|
||||
|
||||
if (ENABLE_MPI)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
|
||||
endif()
|
||||
|
||||
if (ENABLE_GE)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external/graph)
|
||||
else()
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/ops)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external/graph)
|
||||
endif()
|
||||
|
||||
if (ENABLE_MINDDATA)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake)
|
||||
endif()
|
||||
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
|
||||
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS_F})
|
|
@ -0,0 +1,109 @@
|
|||
# Build feature switches.  Each option toggles an optional backend or
# debugging facility; most default to OFF.
option(ENABLE_D "Enable d" OFF)
option(ENABLE_GPU "Enable gpu" OFF)
option(ENABLE_CPU "Enable cpu" OFF)
option(ENABLE_GE "Enable graph engine as backend to execute" OFF)
option(ENABLE_MINDDATA "Enable minddata compile" OFF)
option(ENABLE_TRAIN "Enable ge train, default off(only infer)" OFF)
option(ENABLE_TESTCASES "Run testcases switch, default off" OFF)
option(DEBUG_MODE "Debug mode, default off" OFF)
# Fix: explicit OFF default for consistency with the other switches
# (option() without a value already defaults to OFF, so behavior is unchanged).
option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs" OFF)
option(ENABLE_LOAD_ANF_IR "Enable load ANF-IR as input of 'infer' stage of pipeline" OFF)
option(ENABLE_COVERAGE "Enable code coverage report" OFF)
option(USE_GLOG "Use glog to output log" OFF)
option(ENABLE_PROFILE "Enable pipeline profile, default off" OFF)
option(ENABLE_TIMELINE "Enable time line record" OFF)
option(ENABLE_DUMP_PROTO "Enable dump anf graph to file in ProtoBuffer format, default on" ON)
option(ENABLE_DUMP_E2E "Enable dump e2e file, default on" OFF)
# Fix: "funciton" typo in the help text.
option(ENABLE_DUMP_IR "Enable dump function graph ir, default on" ON)
option(ENABLE_MPI "enable mpi" OFF)
option(ENABLE_AKG "enable akg" OFF)

# Hardening flags for GCC builds.
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
endif()

if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
    set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -Wsign-compare")
endif()

if(ENABLE_COVERAGE)
    set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -ftest-coverage")
    set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}")
endif()

if(ENABLE_ASAN)
    # Link the sanitizer runtime statically; the static archive name differs
    # between GCC (libasan) and other toolchains (libsan).
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
        set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -static-libasan -fsanitize=undefined")
    else()
        set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -static-libsan -fsanitize=undefined")
    endif()
endif()

if(DEBUG_MODE)
    set(CMAKE_BUILD_TYPE "Debug")
else()
    add_compile_definitions(MEM_REUSE_DEBUG)
    set(CMAKE_BUILD_TYPE "Release")
endif()

# Disable pybind11's LTO flags on aarch64 and on release builds.
if((CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") OR (CMAKE_BUILD_TYPE STREQUAL Release))
    set(PYBIND11_LTO_CXX_FLAGS FALSE)
endif()

if(NOT BUILD_PATH)
    set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build")
endif()

# Derived switches: data queues used by the GE/D and GPU backends.
if(ENABLE_GE OR ENABLE_D)
    set(ENABLE_TDTQUE ON)
endif()

if(ENABLE_GPU)
    set(ENABLE_GPUQUE ON)
endif()

# Propagate the selected features to the C++ sources as preprocessor macros.
if(ENABLE_GE)
    add_compile_definitions(ENABLE_GE)
    add_compile_definitions(CUSTOM_OP)
endif()

if(ENABLE_TRAIN)
    add_compile_definitions(ENABLE_TRAIN=1)
else()
    add_compile_definitions(ENABLE_TRAIN=0)
endif()

if(USE_GLOG)
    add_compile_definitions(USE_GLOG)
endif()

if(ENABLE_PROFILE)
    add_compile_definitions(ENABLE_PROFILE)
endif()

if(ENABLE_TIMELINE)
    add_compile_definitions(ENABLE_TIMELINE)
endif()

if(ENABLE_LOAD_ANF_IR)
    add_compile_definitions(ENABLE_LOAD_ANF_IR)
endif()

if(ENABLE_TESTCASES OR (NOT ENABLE_D AND NOT ENABLE_GE))
    add_compile_definitions(NO_DLIB=1)
endif()

# Modernized: legacy `endif(ENABLE_DUMP_IR)` form replaced by plain endif().
if(ENABLE_DUMP_IR)
    add_compile_definitions(ENABLE_DUMP_IR)
endif()

if(ENABLE_MINDDATA)
    add_compile_definitions(ENABLE_MINDDATA)
    if(ENABLE_TDTQUE)
        add_compile_definitions(ENABLE_TDTQUE)
    endif()
endif()

if(ENABLE_DUMP_E2E)
    add_compile_definitions(ENABLE_DUMP_E2E)
endif()
|
|
@ -0,0 +1,349 @@
|
|||
include(FetchContent)
set(FETCHCONTENT_QUIET OFF)

# Add a submodule directory and append its object-library objects to the
# list variable named by `des_submodule_objs` in the caller's scope.
# Aborts configuration if the target is missing or added twice.
function(mindspore_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj)
    add_subdirectory(${sub_dir})

    if(NOT TARGET ${submodule_name_obj})
        message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}")
    endif()
    if("$<TARGET_OBJECTS:${submodule_name_obj}>" IN_LIST ${des_submodule_objs})
        message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. in ${CMAKE_CURRENT_LIST_FILE}")
    endif()

    # Append the generator expression and publish the updated list upward.
    list(APPEND ${des_submodule_objs} $<TARGET_OBJECTS:${submodule_name_obj}>)
    set(${des_submodule_objs} ${${des_submodule_objs}} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Local cache directory where third-party packages are built and installed.
get_filename_component(_MS_LIB_CACHE ~/.mslib REALPATH)
if(NOT EXISTS ${_MS_LIB_CACHE})
    file(MAKE_DIRECTORY ${_MS_LIB_CACHE})
endif()
# set(FETCHCONTENT_BASE_DIR ${_MS_LIB_CACHE})
# set(CMAKE_PREFIX_PATH ${_MS_LIB_CACHE})

# Optional mirror server for dependency downloads (see __download_pkg).
if(DEFINED ENV{MSLIBS_SERVER})
    set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER})
    message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}")
endif()

if(LOCAL_LIBS_SERVER)
    # Make sure the mirror host bypasses any configured proxy.
    # FIX: the original tested `if(NOT ENV{no_proxy})`, which evaluates the
    # (undefined) plain variable literally named "ENV{no_proxy}" rather than
    # the environment variable, so an existing no_proxy value was always
    # overwritten.  `DEFINED ENV{no_proxy}` tests the environment correctly.
    if(NOT DEFINED ENV{no_proxy})
        set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}")
    else()
        # Only append the server if it is not already listed.
        string(FIND "$ENV{no_proxy}" ${LOCAL_LIBS_SERVER} IP_POS)
        if(${IP_POS} EQUAL -1)
            set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}")
        endif()
    endif()
endif()
|
||||
|
||||
# Fetch a source archive from `pkg_url` (MD5-verified), preferring the local
# mirror when LOCAL_LIBS_SERVER is set.  Exports <pkg_name>_SOURCE_DIR to the
# caller once the content is populated.
function(__download_pkg pkg_name pkg_url pkg_md5)
    if(LOCAL_LIBS_SERVER)
        # Try the mirror first; the upstream URL stays in the list as a fallback.
        get_filename_component(_URL_FILE_NAME ${pkg_url} NAME)
        set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url})
    endif()

    FetchContent_Declare(
            ${pkg_name}
            URL ${pkg_url}
            URL_HASH MD5=${pkg_md5}
            )
    FetchContent_GetProperties(${pkg_name})
    message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}")
    if(NOT ${pkg_name}_POPULATED)
        FetchContent_Populate(${pkg_name})
        set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE)
    endif()
endfunction()
|
||||
|
||||
# Fetch a package from git, or -- when a mirror is configured -- as an archive
# keyed by the commit id.  Exports <pkg_name>_SOURCE_DIR to the caller.
function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_md5)
    if(LOCAL_LIBS_SERVER)
        # The mirror serves per-commit archives; verify with the given MD5.
        set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}")
        FetchContent_Declare(
                ${pkg_name}
                URL ${pkg_url}
                URL_HASH MD5=${pkg_md5}
                )
    else()
        FetchContent_Declare(
                ${pkg_name}
                GIT_REPOSITORY ${pkg_url}
                GIT_TAG ${pkg_git_commit})
    endif()
    FetchContent_GetProperties(${pkg_name})
    message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}")
    if(NOT ${pkg_name}_POPULATED)
        FetchContent_Populate(${pkg_name})
        set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE)
    endif()
endfunction()
|
||||
|
||||
|
||||
# Probe <pkg>_BASE_DIR for an optional executable (`pkg_exe`) and the library
# names passed in ARGN; wrap everything found as GLOBAL IMPORTED targets named
# <pkg>::<name>.  Returns early (leaving <pkg>_LIBS unset) on the first miss,
# which callers treat as "package not cached yet".
function(__find_pkg_then_add_target pkg_name pkg_exe)
    unset(${pkg_name}_LIBS)

    message("_FIND:${${pkg_name}_BASE_DIR}")

    if(pkg_exe)
        # Search only the package's own bin dir, never system paths.
        find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH)
        if(NOT ${pkg_exe}_EXE)
            return()
        endif()
        add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL)
        set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES
                IMPORTED_LOCATION ${${pkg_exe}_EXE}
                )
        message("found ${${pkg_exe}_EXE}")
    endif()

    foreach(_LIB_NAME ${ARGN})
        set(_LIB_SEARCH_NAME ${_LIB_NAME})
        set(_LIB_TYPE SHARED)
        if(${pkg_name}_USE_STATIC_LIBS)
            # Force the static artifact name (e.g. libfoo.a on UNIX).
            set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}")
            set(_LIB_TYPE STATIC)
        endif()
        # Reset the cache entry so a previously failed search is retried.
        set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND)
        find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/lib NO_DEFAULT_PATH)
        if(NOT ${_LIB_NAME}_LIB)
            return()
        endif()
        add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL)
        set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES
                INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include"
                IMPORTED_LOCATION ${${_LIB_NAME}_LIB}
                )
        list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME})
        message("found ${${_LIB_NAME}_LIB}")
        # Cache the directory that actually holds the libraries.
        string(REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB})
        set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL)
    endforeach(_LIB_NAME)

    set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Run an external command and abort the configure step on a non-zero exit.
# Usage: __exec_cmd(COMMAND <cmd> <args...> WORKING_DIRECTORY <dir>)
function(__exec_cmd)
    set(options)
    set(oneValueArgs WORKING_DIRECTORY)
    set(multiValueArgs COMMAND)
    cmake_parse_arguments(EXEC "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    execute_process(COMMAND ${EXEC_COMMAND}
            WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY}
            RESULT_VARIABLE RESULT)
    if(NOT RESULT EQUAL "0")
        message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}")
    endif()
endfunction()
|
||||
|
||||
# Compare the MD5s of the current patch files against the record kept in the
# lib cache; when they differ, remove the stale FetchContent sub-build so the
# package is re-downloaded and re-patched, then refresh the record.
# NOTE(review): the `pkg_patches` parameter is unused -- the function relies on
# PKG_PATCHES and pkg_name being visible from the calling function's scope.
function(__check_patches pkg_patches)
    # check patches
    if(PKG_PATCHES)
        # Ensure the record file exists, then load the previous hash list.
        file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.md5)
        file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${pkg_name}_PATCHES_MD5)

        message("patches md5:${${pkg_name}_PATCHES_MD5}")

        # Recompute the hash list for the patches we have now.
        set(${pkg_name}_PATCHES_NEW_MD5 )
        foreach(_PATCH ${PKG_PATCHES})
            file(MD5 ${_PATCH} _PF_MD5)
            set(${pkg_name}_PATCHES_NEW_MD5 "${${pkg_name}_PATCHES_NEW_MD5},${_PF_MD5}")
        endforeach(_PATCH)

        if(NOT ${pkg_name}_PATCHES_MD5 STREQUAL ${pkg_name}_PATCHES_NEW_MD5)
            set(${pkg_name}_PATCHES ${PKG_PATCHES})
            file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild")
            file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${${pkg_name}_PATCHES_NEW_MD5})
            message("patches changed : ${${pkg_name}_PATCHES_NEW_MD5}")
        endif()
    endif()
endfunction()
|
||||
|
||||
# find_package() restrictions: look only where we are told, never in system
# or registry locations.
set(MS_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH
                            NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH
                            NO_CMAKE_SYSTEM_PACKAGE_REGISTRY)
set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE)

# Main entry point for third-party dependencies: download (URL/git/local dir),
# optionally patch, build (cmake or configure&&make) and install the package
# into the per-configuration cache dir, then expose it as imported targets
# (<pkg>::<lib>) and <pkg>_INC / <pkg>_LIBS variables in the caller's scope.
function(mindspore_add_pkg pkg_name)
    set(options)
    set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY)
    set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES)
    cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # Keep the caller's casing for find_package(); use lower case elsewhere.
    set(__FIND_PKG_NAME ${pkg_name})
    string(TOLOWER ${pkg_name} pkg_name)
    message("pkg name:${__FIND_PKG_NAME},${pkg_name}")

    # Hash of the patch files; folded into the configuration hash below.
    set(${pkg_name}_PATCHES_HASH )
    foreach(_PATCH ${PKG_PATCHES})
        file(MD5 ${_PATCH} _PF_MD5)
        set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_MD5}")
    endforeach(_PATCH)

    # The cache directory name is keyed by an MD5 over compiler versions, the
    # full argument list, build flags and patch hashes, so any configuration
    # change yields a fresh build.  (Whitespace inside this literal is part of
    # the hashed text -- do not reformat.)
    set(${pkg_name}_CONFIG_TXT
            "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION}
            ${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH}
            ${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}")
    string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT})
    string(MD5 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT})

    message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}")

    set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH})
    set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL)

    # Fast path: a header-only package already installed in the cache.
    if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY)
        set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE)
        add_library(${pkg_name} INTERFACE)
        target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})
        return()
    endif()

    if(NOT PKG_EXE)
        set(PKG_EXE 0)
    endif()

    # Point find_package() at the cache dir, here and in the caller.
    set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR})
    set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE)

    # Fast path: the package was built by an earlier run and is cached.
    if(PKG_LIBS)
        __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
        if(${pkg_name}_LIBS)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found libs: ${${pkg_name}_LIBS}")
            return()
        endif()
    elseif(NOT PKG_HEAD_ONLY)
        find_package(${__FIND_PKG_NAME} ${PKG_VER} ${MS_FIND_NO_DEFAULT_PATH})
        if(${__FIND_PKG_NAME}_FOUND)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found pkg: ${__FIND_PKG_NAME}")
            return()
        endif()
    endif()

    # Obtain the sources: a local dir, a git checkout, or an archive URL.
    if(NOT PKG_DIR)
        if(PKG_GIT_REPOSITORY)
            __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5})
        else()
            __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5})
        endif()
    else()
        set(${pkg_name}_SOURCE_DIR ${PKG_DIR})
    endif()
    # Persist the configuration text; its presence marks the dir as populated.
    file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT})
    message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}")

    # Apply the patch series to the fresh checkout before building.
    foreach(_PATCH_FILE ${PKG_PATCHES})
        message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_PATCH_FILE}")
        execute_process(COMMAND patch -p1 INPUT_FILE ${_PATCH_FILE}
                WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}
                RESULT_VARIABLE Result)
        if(NOT Result EQUAL "0")
            message(FATAL_ERROR "Failed patch: ${_PATCH_FILE}")
        endif()
    endforeach(_PATCH_FILE)

    # Serialize concurrent CMake runs installing into the same cache dir.
    file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600)
    if(NOT ${pkg_name}_LOCK_RET EQUAL "0")
        message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}")
    endif()

    if(${pkg_name}_SOURCE_DIR)
        if(PKG_HEAD_ONLY)
            # Header-only: copy the sources straight into the cache dir.
            file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*)
            file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR})
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE)
            add_library(${pkg_name} INTERFACE)
            target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})

        elseif(PKG_CMAKE_OPTION)
            # in cmake: configure in _build/ then build+install.
            file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)
            if(${pkg_name}_CFLAGS)
                set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}")
            endif()
            if(${pkg_name}_CXXFLAGS)
                set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}")
            endif()

            if(${pkg_name}_LDFLAGS)
                if(${pkg_name}_USE_STATIC_LIBS)
                    #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
                else()
                    set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
                endif()
            endif()

            __exec_cmd(COMMAND ${CMAKE_COMMAND} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR}
                    ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS}
                    -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ..
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

            __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j8
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

        else()
            # in configure && make
            if(${pkg_name}_CFLAGS)
                set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}")
            endif()
            if(${pkg_name}_CXXFLAGS)
                set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}")
            endif()
            if(${pkg_name}_LDFLAGS)
                set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}")
            endif()
            if(PKG_PRE_CONFIGURE_COMMAND)
                __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND}
                        WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()

            if(PKG_CONFIGURE_COMMAND)
                __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND}
                        ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}
                        --prefix=${${pkg_name}_BASE_DIR}
                        WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()
            set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION})
            if(NOT PKG_CONFIGURE_COMMAND)
                # No configure step: pass the flags on the make command line.
                set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION}
                        ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS})
            endif()
            # build
            __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j8
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})

            if(PKG_INSTALL_INCS OR PKG_INSTALL_LIBS)
                # Manual install: copy the requested headers and libraries.
                file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS})
                file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS})
                file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include)
                file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib)
            else()
                __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()
        endif()
    endif()

    # Re-resolve what was just installed and expose it to the caller.
    if(PKG_LIBS)
        __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
        set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
        if(NOT ${pkg_name}_LIBS)
            message(FATAL_ERROR "Can not find pkg: ${pkg_name}")
        endif()
    else()
        find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET ${MS_FIND_NO_DEFAULT_PATH})
        if(${__FIND_PKG_NAME}_FOUND)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}")
            return()
        endif()
    endif()
endfunction()
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"DumpSettings": {
|
||||
"enable": false,
|
||||
"trans_flag": false,
|
||||
"path": "/tmp/net/",
|
||||
"net_name": "ResNet50",
|
||||
"mode": 0,
|
||||
"iteration": 0,
|
||||
"kernels": ["TensorAdd"]
|
||||
},
|
||||
|
||||
"DumpSettingsSpec": {
|
||||
"enable": "true: dump enable false: dump disable",
|
||||
"trans_flag": "true: trans to host format,false: not trans format",
|
||||
"path": "the dump file folder",
|
||||
"net_name": "net name eg:ResNet50",
|
||||
"mode": "0: dump all kernels 1: dump kernels in kernels list",
|
||||
"iteration": "0: all iteration others: specified iteration ",
|
||||
"kernels": "kernel name list need to be dump"
|
||||
},
|
||||
"other": {}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"DumpSettings": {
|
||||
"enable": false,
|
||||
"trans_flag": false,
|
||||
"path": "/tmp/hccllog/0",
|
||||
"net_name": "ResNet50",
|
||||
"mode": 0,
|
||||
"iteration": 0,
|
||||
"kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
|
||||
},
|
||||
|
||||
"DumpSettingsSpec": {
|
||||
"enable": "true: dump enable false: dump disable",
|
||||
"trans_flag": "true: trans to host format,false: not trans format",
|
||||
"path": "the dump file folder",
|
||||
"net_name": "net name eg:ResNet50",
|
||||
"mode": "0: dump all kernels 1: dump kernels in kernels list",
|
||||
"iteration": "0: all iteration others: specified iteration ",
|
||||
"kernels": "kernel name list need to be dump"
|
||||
},
|
||||
"other": {}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"DumpSettings": {
|
||||
"enable": false,
|
||||
"trans_flag": false,
|
||||
"path": "/tmp/hccllog/1",
|
||||
"net_name": "ResNet50",
|
||||
"mode": 0,
|
||||
"iteration": 0,
|
||||
"kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
|
||||
},
|
||||
|
||||
"DumpSettingsSpec": {
|
||||
"enable": "true: dump enable false: dump disable",
|
||||
"trans_flag": "true: trans to host format,false: not trans format",
|
||||
"path": "the dump file folder",
|
||||
"net_name": "net name eg:ResNet50",
|
||||
"mode": "0: dump all kernels 1: dump kernels in kernels list",
|
||||
"iteration": "0: all iteration others: specified iteration ",
|
||||
"kernels": "kernel name list need to be dump"
|
||||
},
|
||||
"other": {}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"board_id": "0x3000",
|
||||
"chip_info": "910",
|
||||
"deploy_mode": "lab",
|
||||
"group_count": "1",
|
||||
"group_list": [
|
||||
{
|
||||
"device_num": "2",
|
||||
"server_num": "1",
|
||||
"group_name": "",
|
||||
"instance_count": "2",
|
||||
"instance_list": [
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "0",
|
||||
"device_ip": "[device_ip]"
|
||||
}
|
||||
],
|
||||
"rank_id": "0",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "1",
|
||||
"device_ip": "[device_ip]"
|
||||
}
|
||||
],
|
||||
"rank_id": "1",
|
||||
"server_id": "[server_id]"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"para_plane_nic_location": "device",
|
||||
"para_plane_nic_name": [
|
||||
"eth0",
|
||||
"eth1"
|
||||
],
|
||||
"para_plane_nic_num": "2",
|
||||
"status": "completed"
|
||||
}
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
{
|
||||
"board_id": "0x0000",
|
||||
"chip_info": "910",
|
||||
"deploy_mode": "lab",
|
||||
"group_count": "1",
|
||||
"group_list": [{
|
||||
"device_num": "16",
|
||||
"server_num": "2",
|
||||
"group_name": "",
|
||||
"instance_count": "16",
|
||||
"instance_list": [{
|
||||
"devices": [{
|
||||
"device_id": "0",
|
||||
"device_ip": "[A_device_ip_0]"
|
||||
}],
|
||||
"rank_id": "0",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "1",
|
||||
"device_ip": "[A_device_ip_1]"
|
||||
}],
|
||||
"rank_id": "1",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "2",
|
||||
"device_ip": "[A_device_ip_2]"
|
||||
}],
|
||||
"rank_id": "2",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "3",
|
||||
"device_ip": "[A_device_ip_3]"
|
||||
}],
|
||||
"rank_id": "3",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "4",
|
||||
"device_ip": "[A_device_ip_4]"
|
||||
}],
|
||||
"rank_id": "4",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "5",
|
||||
"device_ip": "[A_device_ip_5]"
|
||||
}],
|
||||
"rank_id": "5",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "6",
|
||||
"device_ip": "[A_device_ip_6]"
|
||||
}],
|
||||
"rank_id": "6",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "7",
|
||||
"device_ip": "[A_device_ip_7]"
|
||||
}],
|
||||
"rank_id": "7",
|
||||
"server_id": "[server_id_A]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "0",
|
||||
"device_ip": "[B_device_ip_0]"
|
||||
}],
|
||||
"rank_id": "8",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "1",
|
||||
"device_ip": "[B_device_ip_1]"
|
||||
}],
|
||||
"rank_id": "9",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "2",
|
||||
"device_ip": "[B_device_ip_2]"
|
||||
}],
|
||||
"rank_id": "10",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "3",
|
||||
"device_ip": "[B_device_ip_3]"
|
||||
}],
|
||||
"rank_id": "11",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "4",
|
||||
"device_ip": "[B_device_ip_4]"
|
||||
}],
|
||||
"rank_id": "12",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "5",
|
||||
"device_ip": "[B_device_ip_5]"
|
||||
}],
|
||||
"rank_id": "13",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "6",
|
||||
"device_ip": "[B_device_ip_6]"
|
||||
}],
|
||||
"rank_id": "14",
|
||||
"server_id": "[server_id_B]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "7",
|
||||
"device_ip": "[B_device_ip_7]"
|
||||
}],
|
||||
"rank_id": "15",
|
||||
"server_id": "[server_id_B]"
|
||||
}
|
||||
]
|
||||
}],
|
||||
"para_plane_nic_location": "device",
|
||||
"para_plane_nic_name": [
|
||||
"eth0",
|
||||
"eth1",
|
||||
"eth2",
|
||||
"eth3",
|
||||
"eth4",
|
||||
"eth5",
|
||||
"eth6",
|
||||
"eth7"
|
||||
],
|
||||
"para_plane_nic_num": "8",
|
||||
"status": "completed",
|
||||
|
||||
"hccl_config_json_spec": {
|
||||
"board_id": "board id, current support 0x0000 or 0x3000",
|
||||
"chip_info": "chip info, current is 910",
|
||||
"deploy_mode": "current use lab",
|
||||
"group_count": "number of groups used",
|
||||
"group_list": "detailed group information",
|
||||
"device_num": "number of devices used, the value is the nth power of 2",
|
||||
"server_num": "number of multiple machines, single machine is 1",
|
||||
"group_name": "default is hccl_world_group or specified",
|
||||
"instance_count": "number of instance used, generally equal to device_num",
|
||||
"instance_list": "detailed instance information",
|
||||
"device_id": "designated davinci device id to use, values start from 0, but no more than single machine total device num. If server_num greater than 1, the id can restart from 0",
|
||||
"device_ip": "ip corresponding to device_id",
|
||||
"rank_id": "the first device must be 0 and then increase in order",
|
||||
"server_id": "can be specified as the machine's ip address",
|
||||
"para_plane_nic_location": "current use device",
|
||||
"para_plane_nic_name": "network card corresponding to device ip",
|
||||
"para_plane_nic_num": "number of network cards used",
|
||||
"status": "current use completed"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
{
|
||||
"board_id": "0x0000",
|
||||
"chip_info": "910",
|
||||
"deploy_mode": "lab",
|
||||
"group_count": "1",
|
||||
"group_list": [{
|
||||
"device_num": "8",
|
||||
"server_num": "1",
|
||||
"group_name": "",
|
||||
"instance_count": "8",
|
||||
"instance_list": [{
|
||||
"devices": [{
|
||||
"device_id": "0",
|
||||
"device_ip": "[device_ip_0]"
|
||||
}],
|
||||
"rank_id": "0",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "1",
|
||||
"device_ip": "[device_ip_1]"
|
||||
}],
|
||||
"rank_id": "1",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "2",
|
||||
"device_ip": "[device_ip_2]"
|
||||
}],
|
||||
"rank_id": "2",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "3",
|
||||
"device_ip": "[device_ip_3]"
|
||||
}],
|
||||
"rank_id": "3",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "4",
|
||||
"device_ip": "[device_ip_4]"
|
||||
}],
|
||||
"rank_id": "4",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "5",
|
||||
"device_ip": "[device_ip_5]"
|
||||
}],
|
||||
"rank_id": "5",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "6",
|
||||
"device_ip": "[device_ip_6]"
|
||||
}],
|
||||
"rank_id": "6",
|
||||
"server_id": "[server_id]"
|
||||
},
|
||||
{
|
||||
"devices": [{
|
||||
"device_id": "7",
|
||||
"device_ip": "[device_ip_7]"
|
||||
}],
|
||||
"rank_id": "7",
|
||||
"server_id": "[server_id]"
|
||||
}
|
||||
]
|
||||
}],
|
||||
"para_plane_nic_location": "device",
|
||||
"para_plane_nic_name": [
|
||||
"eth0",
|
||||
"eth1",
|
||||
"eth2",
|
||||
"eth3",
|
||||
"eth4",
|
||||
"eth5",
|
||||
"eth6",
|
||||
"eth7"
|
||||
],
|
||||
"para_plane_nic_num": "8",
|
||||
"status": "completed",
|
||||
|
||||
"hccl_config_json_spec": {
|
||||
"board_id": "board id, current support 0x0000 or 0x3000",
|
||||
"chip_info": "chip info, current is 910",
|
||||
"deploy_mode": "current use lab",
|
||||
"group_count": "number of groups used",
|
||||
"group_list": "detailed group information",
|
||||
"device_num": "number of devices used, the value is the nth power of 2",
|
||||
"server_num": "single machine is 1",
|
||||
"group_name": "default is hccl_world_group or specified",
|
||||
"instance_count": "number of instance used, generally equal to device_num",
|
||||
"instance_list": "detailed instance information",
|
||||
"device_id": "designated davinci device id to use, values start from 0, but no more than single machine total device num",
|
||||
"device_ip": "ip corresponding to device_id",
|
||||
"rank_id": "the first device must be 0 and then increase in order",
|
||||
"server_id": "can be specified as the machine's ip address",
|
||||
"para_plane_nic_location": "current use device",
|
||||
"para_plane_nic_name": "network card corresponding to device ip",
|
||||
"para_plane_nic_num": "number of network cards used",
|
||||
"status": "current use completed"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
set -x
set -e

# Make the pipeline dump its graphs so they can be replayed below.
export SAVE_GRAPHS=YES

# print usage message
function usage()
{
    # $0 expands inside the heredoc, matching the original echo output.
    cat <<EOF
Usage:
bash $0 [-g] [-d] [-a] [-h] [-f file]
e.g. $0 -f 3_specialize.dat

Options:
 -g Generate ir file for debug
 -d Debug dumped ir
 -a Execute all steps, default
 -f File to be parse
 -h Print usage
EOF
}
|
||||
|
||||
# check and set options
|
||||
# check and set options: fills the globals $mode and $FILE_NAME used by the
# driver code at the bottom of the script.
function checkopts()
{
    # Mode constants and defaults.
    MODE_GEN=0
    MODE_DBG=1
    MODE_ALL=2
    FILE_NAME="3_optimize.dat"
    mode="${MODE_ALL}" # default execute all steps

    # Process the options
    while getopts 'gdaf:h' opt
    do
        case "${opt}" in
            g) mode="${MODE_GEN}" ;;
            d) mode="${MODE_DBG}" ;;
            a) mode="${MODE_ALL}" ;;
            f)
                FILE_NAME="$OPTARG"
                # Fail fast when the dump file to debug does not exist.
                if ! [ -f "${FILE_NAME}" ]; then
                    echo "File $FILE_NAME does not exist"
                    usage
                    exit 1
                fi
                ;;
            h)
                usage
                exit 0
                ;;
            *)
                echo "Unknown option ${opt}!"
                usage
                exit 1
                ;;
        esac
    done
}
|
||||
|
||||
# Parse command-line options into $mode and $FILE_NAME.
checkopts "$@"

# Rebuild the C++ core and drop the fresh python extension into the package.
cd build/mindspore/
make -j8
cp -v mindspore/ccsrc/_c_expression.cpython-*.so ../../mindspore/
cd -

# The unit test whose pipeline is dumped / replayed.
UT_NAME="./tests/ut/python/model/test_lenet.py::test_lenet5_train_sens"
#UT_NAME="./tests/python/ops/test_math_ops.py::test_matmul_grad"
#UT_NAME="./tests/python/exec/resnet_example.py::test_compile"
#UT_NAME="./tests/perf_test/test_bert_train.py::test_bert_train"

if [[ "${mode}" == "${MODE_GEN}" || "${mode}" == "${MODE_ALL}" ]]; then
    # Generation step: run the UT once, dumping IR and pickled objects.
    rm -rf pkl_objs
    mkdir -p pkl_objs

    echo "MS_IR_PATH=$(pwd)/pkl_objs pytest -s ${UT_NAME}"
    MS_IR_PATH=$(pwd)/pkl_objs/ pytest -s "${UT_NAME}"
    #pytest -s $UT_NAME

    # Dump stages that can be selected with -f:
    # 1_resolve.dat
    # 3_specialize.dat
    # 4_simplify_data_structures.dat
    # 5_opt.dat
    # 6_opt2.dat
    # 7_opt_ge_adaptor_special.dat
    # 8_cconv.dat
    # 9_validate.dat
    cp "${FILE_NAME}" anf_ir_file.dbg

    # Snapshot the pickled objects for the debug replay.
    rm -rf pkl_objs.dbg
    cp -rf pkl_objs pkl_objs.dbg
fi

if [[ "${mode}" == "${MODE_DBG}" || "${mode}" == "${MODE_ALL}" ]]; then
    # Debug step: replay the UT with the dumped IR as the pipeline input.
    echo "MS_IR_FILE=$(pwd)/anf_ir_file.dbg MS_IR_PATH=$(pwd)/pkl_objs.dbg/ pytest -s ${UT_NAME}"
    MS_IR_FILE=$(pwd)/anf_ir_file.dbg MS_IR_PATH=$(pwd)/pkl_objs.dbg/ pytest -s "${UT_NAME}"
fi
|
Binary file not shown.
After Width: | Height: | Size: 76 KiB |
Binary file not shown.
After Width: | Height: | Size: 121 KiB |
Binary file not shown.
After Width: | Height: | Size: 53 KiB |
Binary file not shown.
After Width: | Height: | Size: 13 KiB |
|
@ -0,0 +1,3 @@
|
|||
# MindSpore Documentation
|
||||
|
||||
The MindSpore documentation is in the [MindSpore Docs](https://gitee.com/mindspore/docs) repository.
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig
|
||||
# Hyper-parameters for BERT-NEZHA pretraining; consumed by main.py.
bert_cfg = edict({
    'epoch_size': 10,
    'num_warmup_steps': 0,
    'start_learning_rate': 1e-4,
    'end_learning_rate': 1,
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    # Fix: these two entries used '=' instead of ':' and were missing the
    # trailing commas — a SyntaxError inside a dict literal.
    'DATA_DIR': "/your/path/examples.tfrecord",
    'SCHEMA_DIR': "/your/path/datasetSchema.json",
    # Model architecture; passed straight to the model zoo BertConfig.
    'bert_config': BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
})
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
|
||||
1. Prepare data
|
||||
Following the data preparation as in BERT, run command as below to get dataset for training:
|
||||
python ./create_pretraining_data.py \
|
||||
--input_file=./sample_text.txt \
|
||||
--output_file=./examples.tfrecord \
|
||||
--vocab_file=./your/path/vocab.txt \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=128 \
|
||||
--max_predictions_per_seq=20 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=5
|
||||
2. Pretrain
|
||||
First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy import allclose
|
||||
from config import bert_cfg as cfg
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore._c_dataengine as deMap
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore import log as logger
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
DATA_DIR = [cfg.DATA_DIR]
|
||||
SCHEMA_DIR = cfg.SCHEMA_DIR
|
||||
|
||||
def me_de_train_dataset(batch_size):
    """Build the pretraining dataset: storage -> int32 casts -> batch -> repeat."""
    repeat_count = cfg.epoch_size
    columns = ["input_ids", "input_mask", "segment_ids",
               "next_sentence_labels", "masked_lm_positions",
               "masked_lm_ids", "masked_lm_weights"]
    ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=columns)
    # Cast each integer column to int32 (same column order as the original
    # chain of map calls).
    type_cast_op = deMap.TypeCastOp("int32")
    for column in ("masked_lm_ids", "masked_lm_positions", "next_sentence_labels",
                   "segment_ids", "input_mask", "input_ids"):
        ds = ds.map(input_columns=column, operations=type_cast_op)
    # apply batch then repeat operations
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    return ds
|
||||
|
||||
|
||||
def weight_variable(shape):
    """Return a Tensor of the given shape, uniform in [-0.1, 0.1).

    The RNG is re-seeded with 1 on every call, so results are reproducible.
    """
    np.random.seed(1)
    values = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
    return Tensor(values)
|
||||
|
||||
|
||||
class ModelCallback(Callback):
    """Training callback that records the first network output of every step."""

    def __init__(self):
        super(ModelCallback, self).__init__()
        # Loss values collected across training steps.
        self.loss_list = []

    def step_end(self, run_context):
        """Store the step's loss and log the raw network outputs."""
        params = run_context.original_args()
        step_loss = params.net_outputs.asnumpy()[0]
        self.loss_list.append(step_loss)
        logger.info("epoch: {}, outputs are {}".format(params.cur_epoch_num, str(params.net_outputs)))
|
||||
|
||||
def test_bert_tdt():
    """Train BERT-NEZHA for one repeat cycle on Ascend, saving checkpoints.

    Builds the dataset, the loss-wrapped network, a Lamb optimizer and a
    Model, then trains with a loss-recording callback and checkpointing.
    """
    # Fix: CheckpointConfig and ModelCheckpoint are used below but were never
    # imported (only Callback is imported at module level) — NameError at
    # runtime. Import them locally to keep the module interface unchanged.
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset(cfg.bert_config.batch_size)
    config = cfg.bert_config
    netwithloss = BertNetworkWithLoss(config, True)
    # decay_filter=lambda x: False disables weight decay for all parameters.
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps,
                     start_learning_rate=cfg.start_learning_rate,
                     end_learning_rate=cfg.end_learning_rate, power=cfg.power,
                     warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck)
    model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False)


if __name__ == '__main__':
    test_bert_tdt()
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# LeNet-on-MNIST hyper-parameters; consumed by main.py.
mnist_cfg = edict({
    'num_classes': 10,             # output classes (digits 0-9)
    'lr': 0.01,                    # learning rate for the Momentum optimizer
    'momentum': 0.9,               # optimizer momentum
    'epoch_size': 1,               # number of training epochs
    'batch_size': 32,              # samples per training batch
    'repeat_size': 1,              # dataset repeat count
    'buffer_size': 1000,           # shuffle buffer size
    'image_height': 32,            # input height after resize
    'image_width': 32,             # input width after resize
    'save_checkpoint_steps': 1875, # checkpoint interval in steps
    'keep_checkpoint_max': 10,     # maximum checkpoints kept on disk
})
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## train and test lenet example ########################
|
||||
1. train lenet and get network model files(.ckpt) :
|
||||
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data
|
||||
|
||||
2. test lenet according to model file:
|
||||
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data
|
||||
--mode test --ckpt_path checkpoint_lenet_1-1_1875.ckpt
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
from config import mnist_cfg as cfg
|
||||
|
||||
import mindspore.dataengine as de
|
||||
import mindspore.nn as nn
|
||||
from mindspore.model_zoo.lenet import LeNet5
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train import Model
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.transforms.c_transforms as C
|
||||
from mindspore.transforms import Inter
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Cell):
    """
    Define loss for network: one-hot encode labels, apply softmax
    cross-entropy and reduce the per-sample losses with a mean.
    """
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        # Class count is taken from the logits' second dimension.
        num_classes = F.shape(logits)[1]
        one_hot_label = self.one_hot(label, num_classes, self.on_value, self.off_value)
        # SoftmaxCrossEntropyWithLogits returns a tuple; element 0 is the loss.
        per_sample_loss = self.cross_entropy(logits, one_hot_label)[0]
        return self.mean(per_sample_loss, (-1,))
|
||||
|
||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = de.MnistDataset(data_path)

    # labels to int32 (no parallel workers, matching the original pipeline)
    mnist_ds = mnist_ds.map(input_columns="label", operations=C.TypeCast(mstype.int32))

    # image pipeline, applied in order: resize -> standardize -> scale to
    # [0, 1] -> channels-last to channels-first
    image_ops = [
        C.Resize((cfg.image_height, cfg.image_width), interpolation=Inter.LINEAR),
        C.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081),
        C.Rescale(1.0 / 255.0, 0.0),
        C.HWC2CHW(),
    ]
    for op in image_ops:
        mnist_ds = mnist_ds.map(input_columns="image", operations=op,
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.shuffle(buffer_size=cfg.buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
                        help='implement phase, set to train or test')
    parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                        help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
                        path where the trained ckpt file')

    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # Build network, loss, optimizer and checkpointing.
    network = LeNet5(cfg.num_classes)
    network.set_train()
    # net_loss = nn.SoftmaxCrossEntropyWithLogits() # support this loss soon
    net_loss = CrossEntropyLoss()
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    if args.mode == 'train':  # train
        ds = create_dataset(os.path.join(args.data_path, args.mode), batch_size=cfg.batch_size,
                            repeat_size=cfg.epoch_size)
        print("============== Starting Training ==============")
        # Consistency fix: use attribute access (cfg.epoch_size) like every
        # other cfg reference in this file, not the lone cfg['epoch_size'].
        model.train(cfg.epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
    elif args.mode == 'test':  # test
        print("============== Starting Testing ==============")
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(network, param_dict)
        ds_eval = create_dataset(os.path.join(args.data_path, "test"), 32, 1)
        acc = model.eval(ds_eval, dataset_sink_mode=False)
        print("============== Accuracy:{} ==============".format(acc))
    else:
        raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
# ResNet50-on-CIFAR10 hyper-parameters; consumed by train.py and eval.py.
config = ed({
    "class_num": 10,                 # CIFAR10 classes
    "batch_size": 32,                # samples per batch
    "loss_scale": 1024,              # fixed loss scale for mixed precision
    "momentum": 0.9,                 # optimizer momentum
    "weight_decay": 1e-4,            # L2 weight decay
    "epoch_size": 90,                # total training epochs
    "buffer_size": 100,              # shuffle buffer size
    "image_height": 224,             # input height after resize
    "image_width": 224,              # input width after resize
    "save_checkpoint": True,         # whether to write checkpoints
    "save_checkpoint_steps": 195,    # checkpoint interval in steps
    "keep_checkpoint_max": 10,       # maximum checkpoints kept on disk
    "save_checkpoint_path": "./",    # checkpoint output directory
    "lr_init": 0.01,                 # warmup start learning rate
    "lr_end": 0.00001,               # final learning rate
    "lr_max": 0.1,                   # peak learning rate
    "warmup_epochs": 5,              # linear warmup epochs
    "lr_decay_mode": "poly"          # schedule: steps, poly or default
})
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from config import config
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset
    """
    # Fix: int(os.getenv(...)) raises TypeError when the variable is unset;
    # default to a single rank-0 device so the pipeline also works standalone.
    device_num = int(os.getenv("DEVICE_NUM", "1"))
    rank_id = int(os.getenv("RANK_ID", "0"))

    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
    else:
        # Shard the dataset across devices for distributed training.
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
                               num_shards=device_num, shard_id=rank_id)

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))
    # NOTE(review): the flip probability rank_id / (rank_id + 1) differs per
    # rank (0.0 on rank 0) — confirm this is intentional, not a typo for 0.5.
    random_horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1))

    resize_op = C.Resize((resize_height, resize_width))
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])

    change_swap_op = C.HWC2CHW()

    # Random augmentations only for training; deterministic ops for both.
    trans = []
    if do_train:
        trans += [random_crop_op, random_horizontal_flip_op]

    trans += [resize_op, rescale_op, normalize_op, change_swap_op]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", operations=type_cast_op)
    ds = ds.map(input_columns="image", operations=trans)

    # apply shuffle operations
    ds = ds.shuffle(buffer_size=config.buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
eval.
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
from dataset import create_dataset
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore.model_zoo.resnet import resnet50
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.communication.management import init
|
||||
|
||||
# Pin all RNG seeds so evaluation runs are reproducible.
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()

# NOTE(review): this raises TypeError when DEVICE_ID is unset — confirm the
# launch scripts always export it.
device_id = int(os.getenv('DEVICE_ID'))

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)

if __name__ == '__main__':
    # HCCL is only needed for distributed training; flattened from the
    # original nested if/else into an if/elif/else chain.
    if args_opt.do_eval:
        context.set_context(enable_hccl=False)
    elif args_opt.run_distribute:
        context.set_context(enable_hccl=True)
        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
        init()
    else:
        context.set_context(enable_hccl=False)

    epoch_size = config.epoch_size
    net = resnet50(class_num=config.class_num)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)

    if args_opt.do_eval:
        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
        step_size = dataset.get_dataset_size()

        # Load trained weights into the network if a checkpoint was given.
        if args_opt.checkpoint_path:
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)
        net.set_train(False)

        model = Model(net, loss_fn=loss, metrics={'acc'})
        res = model.eval(dataset)
        print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

    Args:
        global_step(int): total steps of the training
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly or default

    Returns:
        np.array, learning rate array
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []

    if lr_decay_mode == 'steps':
        # Piecewise-constant: divide lr by 10 at 30%, 60% and 80% of training.
        first_drop = 0.3 * total_steps
        second_drop = 0.6 * total_steps
        third_drop = 0.8 * total_steps
        for step in range(total_steps):
            if step < first_drop:
                scale = 1.0
            elif step < second_drop:
                scale = 0.1
            elif step < third_drop:
                scale = 0.01
            else:
                scale = 0.001
            lr_each_step.append(lr_max * scale)
    elif lr_decay_mode == 'poly':
        # Linear warmup from lr_init to lr_max, then quadratic decay to zero.
        step_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) if warmup_steps != 0 else 0
        for step in range(total_steps):
            if step < warmup_steps:
                lr_each_step.append(float(lr_init) + step_inc * float(step))
            else:
                progress = (float(step) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))
                decayed = float(lr_max) * (1.0 - progress) * (1.0 - progress)
                lr_each_step.append(max(decayed, 0.0))
    else:
        # 'default': linear warmup then linear decay from lr_max to lr_end.
        for step in range(total_steps):
            if step < warmup_steps:
                lr_each_step.append(lr_init + (lr_max - lr_init) * step / warmup_steps)
            else:
                lr_each_step.append(lr_max - (lr_max - lr_end) * (step - warmup_steps) / (total_steps - warmup_steps))

    # Drop the steps already consumed before global_step.
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step[global_step:]
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
then
    echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
    exit 1
fi

# Fix: quote the positional parameters so paths containing spaces do not
# break the tests; also fix the "DMINDSPORE" typo in the error message.
if [ ! -f "$1" ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
    exit 1
fi

if [ ! -d "$2" ]
then
    echo "error: DATASET_PATH=$2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH="$1"

# Launch one training process per device, each in its own work directory.
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path="$2" &> log &
    cd ..
done
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
then
    echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Fix: quote the positional parameters so paths containing spaces do not
# break the tests.
if [ ! -d "$1" ]
then
    echo "error: DATASET_PATH=$1 is not a directory"
    exit 1
fi

if [ ! -f "$2" ]
then
    echo "error: CHECKPOINT_PATH=$2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Run evaluation in a clean work directory in the background.
if [ -d "infer" ];
then
    rm -rf ./infer
fi
mkdir ./infer
cp *.py ./infer
cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path="$1" --checkpoint_path="$2" &> log &
cd ..
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
    exit 1
fi

# Fix: quote the positional parameter so paths containing spaces do not
# break the test.
if [ ! -d "$1" ]
then
    echo "error: DATASET_PATH=$1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0

# Run training in a clean work directory in the background.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp *.py ./train
cp *.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --dataset_path="$1" &> log &
cd ..
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
from dataset import create_dataset
|
||||
from lr_generator import get_lr
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.model_zoo.resnet import resnet50
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.communication.management import init
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
|
||||
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(enable_task_sink=True, device_id=device_id)
|
||||
context.set_context(enable_loop_sink=True)
|
||||
context.set_context(enable_mem_reuse=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.do_eval:
|
||||
context.set_context(enable_hccl=False)
|
||||
else:
|
||||
if args_opt.run_distribute:
|
||||
context.set_context(enable_hccl=True)
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
|
||||
init()
|
||||
else:
|
||||
context.set_context(enable_hccl=False)
|
||||
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num)
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
|
||||
|
||||
if args_opt.do_train:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
|
||||
lr_decay_mode='poly'))
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Config parameters for YOLOv3 models."""
|
||||
|
||||
|
||||
class ConfigYOLOV3ResNet18:
    """
    Config parameters for YOLOv3 (ResNet18 backbone).

    Examples:
        ConfigYoloV3ResNet18.
    """
    # Network input resolution (H, W) and input feature shape (N, C, H, W).
    img_shape = [352, 640]
    feature_shape = [32, 3, 352, 640]
    num_classes = 80

    # ResNet18 backbone layout: per-stage input/output channels, block
    # counts, and strides.
    backbone_input_shape = [64, 64, 128, 256]
    backbone_shape = [64, 128, 256, 512]
    backbone_layers = [2, 2, 2, 2]
    backbone_stride = [1, 2, 2, 2]

    # Predictions with IoU below this threshold count as background.
    ignore_threshold = 0.5

    # Nine anchor (w, h) pairs, three per detection scale.
    anchor_scales = [(10, 13), (16, 30), (33, 23),
                     (30, 61), (62, 45), (59, 119),
                     (116, 90), (156, 198), (163, 326)]
    # Output channels per detection head:
    # 3 anchors * (num_classes + 4 box coords + 1 objectness).
    out_channel = int(len(anchor_scales) / 3 * (num_classes + 5))
|
|
@ -0,0 +1,389 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""YOLOv3 dataset"""
|
||||
from __future__ import division
|
||||
|
||||
import abc
|
||||
import io
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as P
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
|
||||
iter_cnt = 0
|
||||
_NUM_BOXES = 50
|
||||
|
||||
def preprocess_fn(image, box, is_training):
    """Preprocess function for dataset.

    Resizes/augments one image, maps and clips its ground-truth boxes, and
    encodes the boxes into the three YOLO output grids.

    Returns (image, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3).
    """
    # Nine anchor (w, h) pairs in pixels, flattened; must stay in sync with
    # ConfigYOLOV3ResNet18.anchor_scales.
    config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
    anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
    do_hsv = False  # HSV color jitter is disabled (see the `if do_hsv` branch)
    max_boxes = 20  # keep at most 20 ground-truth boxes per image
    num_classes = ConfigYOLOV3ResNet18.num_classes

    def _rand(a=0., b=1.):
        # Uniform random float in [a, b).
        return np.random.rand() * (b - a) + a

    def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
        """Get true boxes.

        Encodes (xmin, ymin, xmax, ymax, class) rows into one target grid
        per detection layer plus fixed-size padded ground-truth buffers.
        """
        # Three detection layers; each uses three of the nine anchors
        # (largest anchors on the coarsest grid).
        num_layers = anchors.shape[0] // 3
        anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        true_boxes = np.array(true_boxes, dtype='float32')
        # input_shape = np.array([in_shape, in_shape], dtype='int32')
        input_shape = np.array(in_shape, dtype='int32')
        # Corner boxes -> (center, size); `// 2.` floors the center.
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        # Normalize to [0, 1]; input_shape is (h, w), so reverse to (w, h).
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]

        # Grid sizes for strides 32, 16 and 8.
        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
                            5 + num_classes), dtype='float32') for l in range(num_layers)]

        # Center anchors at the origin for IoU-vs-anchor computation.
        anchors = np.expand_dims(anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max

        # Drop degenerate boxes (width < 1 pixel after clipping).
        valid_mask = boxes_wh[..., 0] >= 1

        wh = boxes_wh[valid_mask]

        if len(wh) >= 1:
            # IoU between every valid box and every anchor (both centered at
            # the origin), to pick the best-matching anchor per box.
            wh = np.expand_dims(wh, -2)
            boxes_max = wh / 2.
            boxes_min = -boxes_max

            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)

            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for l in range(num_layers):
                    if n in anchor_mask[l]:
                        # (j, i) is the grid cell containing the box center;
                        # k is the anchor slot within this layer.
                        i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
                        j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
                        k = anchor_mask[l].index(n)

                        c = true_boxes[t, 4].astype('int32')
                        # Write box, objectness, and one-hot class label.
                        y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                        y_true[l][j, i, k, 4] = 1.
                        y_true[l][j, i, k, 5 + c] = 1.

        # Fixed-size (50, 4) buffers of the boxes assigned to each layer,
        # zero-padded so output shapes stay static for graph mode.
        pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32)
        pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32)
        pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32)

        mask0 = np.reshape(y_true[0][..., 4:5], [-1])
        gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
        gt_box0 = gt_box0[mask0 == 1]
        pad_gt_box0[:gt_box0.shape[0]] = gt_box0

        mask1 = np.reshape(y_true[1][..., 4:5], [-1])
        gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
        gt_box1 = gt_box1[mask1 == 1]
        pad_gt_box1[:gt_box1.shape[0]] = gt_box1

        mask2 = np.reshape(y_true[2][..., 4:5], [-1])
        gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
        gt_box2 = gt_box2[mask2 == 1]
        pad_gt_box2[:gt_box2.shape[0]] = gt_box2

        return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2

    def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
        """Data augmentation function."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        iw, ih = image.size
        ori_image_shape = np.array([ih, iw], np.int32)
        h, w = image_size

        if not is_training:
            # Eval path: plain resize + [0, 1] normalization, no randomness.
            image = image.resize((w, h), Image.BICUBIC)
            image_data = np.array(image) / 255.
            if len(image_data.shape) == 2:
                # Grayscale image: replicate the single channel to RGB.
                image_data = np.expand_dims(image_data, axis=-1)
                image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
            image_data = image_data.astype(np.float32)

            # correct boxes
            box_data = np.zeros((max_boxes, 5))
            if len(box) >= 1:
                np.random.shuffle(box)
                if len(box) > max_boxes:
                    box = box[:max_boxes]
                # xmin ymin xmax ymax
                box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
                box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
                box_data[:len(box)] = box
            else:
                # NOTE(review): with no boxes both outputs become None and
                # feed into _preprocess_true_boxes below — confirm callers
                # never pass empty annotations on the eval path.
                image_data, box_data = None, None

            # preprocess bounding boxes
            bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
                _preprocess_true_boxes(box_data, anchors, image_size)

            return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
                   ori_image_shape, gt_box1, gt_box2, gt_box3

        flip = _rand() < .5
        # correct boxes
        box_data = np.zeros((max_boxes, 5))
        while True:
            # Prevent the situation that all boxes are eliminated
            # Random aspect-ratio jitter and scale for the placement below.
            new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
                     _rand(1 - jitter, 1 + jitter)
            scale = _rand(0.25, 2)

            if new_ar < 1:
                nh = int(scale * h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * w)
                nh = int(nw / new_ar)

            # Random offset of the resized image on the (w, h) canvas.
            dx = int(_rand(0, w - nw))
            dy = int(_rand(0, h - nh))

            if len(box) >= 1:
                t_box = box.copy()
                np.random.shuffle(t_box)
                # Map boxes into the resized/shifted image coordinates.
                t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
                t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
                if flip:
                    t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
                # Clip to the canvas.
                t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
                t_box[:, 2][t_box[:, 2] > w] = w
                t_box[:, 3][t_box[:, 3] > h] = h
                box_w = t_box[:, 2] - t_box[:, 0]
                box_h = t_box[:, 3] - t_box[:, 1]
                t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid box

            # NOTE(review): if `box` is empty, t_box is unbound here and the
            # loop cannot terminate — confirm training samples always carry
            # at least one annotation.
            if len(t_box) >= 1:
                box = t_box
                break

        box_data[:len(box)] = box
        # resize image
        image = image.resize((nw, nh), Image.BICUBIC)
        # place image
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # convert image to gray or not
        gray = _rand() < .25
        if gray:
            image = image.convert('L').convert('RGB')

        # when the channels of image is 1
        image = np.array(image)
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        # distort image
        # hue/sat/val parameters are rebound here to the sampled jitter values.
        hue = _rand(-hue, hue)
        sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
        val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
        image_data = image / 255.
        if do_hsv:
            # Dead branch while do_hsv is False above.
            x = rgb_to_hsv(image_data)
            x[..., 0] += hue
            x[..., 0][x[..., 0] > 1] -= 1
            x[..., 0][x[..., 0] < 0] += 1
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x > 1] = 1
            x[x < 0] = 0
            image_data = hsv_to_rgb(x)  # numpy array, 0 to 1
        image_data = image_data.astype(np.float32)

        # preprocess bounding boxes
        bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
            _preprocess_true_boxes(box_data, anchors, image_size)

        return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
               ori_image_shape, gt_box1, gt_box2, gt_box3

    images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
    return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
|
||||
|
||||
|
||||
def anno_parser(annos_str):
    """Annotation parser.

    Converts each comma-separated annotation string into a list of ints.
    """
    return [[int(field) for field in entry.strip().split(',')] for entry in annos_str]
|
||||
|
||||
|
||||
def expand_path(path):
    """Get file list from path.

    Returns the names of the regular files directly inside `path`.

    Raises:
        RuntimeError: if `path` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise RuntimeError("Path given is not valid.")
    return [entry for entry in os.listdir(path)
            if os.path.isfile(os.path.join(path, entry))]
|
||||
|
||||
|
||||
def read_image(img_path):
    """Read image with PIL and return it as a numpy array."""
    with open(img_path, "rb") as f:
        raw = f.read()
    img = Image.open(io.BytesIO(raw))
    return np.array(img)
|
||||
|
||||
|
||||
class BaseDataset():
    """Base dataset for GeneratorDataset iterators.

    Subclasses implement `_load_samples` to fill `self.samples` (image file
    names) and `self.image_anno_dict` (file name -> parsed annotations).
    """

    def __init__(self, image_dir, anno_path):
        self.image_dir = image_dir    # directory holding the image files
        self.anno_path = anno_path    # annotation text file
        self.cur_index = 0
        self.samples = []             # valid image file names
        self.image_anno_dict = {}     # file name -> annotation rows
        self._load_samples()

    def __getitem__(self, item):
        return self._next_data(self.samples[item], self.image_dir, self.image_anno_dict)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample, image_dir, image_anno_dict):
        """Load one (image, annotation) pair as numpy arrays."""
        image = read_image(os.path.join(image_dir, sample))
        return [np.array(image), np.array(image_anno_dict[sample])]

    @abc.abstractmethod
    def _load_samples(self):
        """Populate `samples` and `image_anno_dict`; provided by subclasses."""
|
||||
|
||||
|
||||
class YoloDataset(BaseDataset):
    """YOLO dataset: keeps only images that also have annotation entries."""

    def _load_samples(self):
        """Load samples."""
        candidates = expand_path(self.image_dir)
        self.samples = self._filter_valid_data(self.anno_path, candidates)
        self.dataset_size = len(self.samples)
        if self.dataset_size == 0:
            raise RuntimeError("Valid dataset is none!")

    def _filter_valid_data(self, anno_path, image_files_raw):
        """Return the sorted file names present in both the annotation file
        and the image directory, recording their parsed annotations."""
        print("Start filter valid data.")
        anno_dict = {}
        with open(anno_path, "rb") as f:
            for raw_line in f.readlines():
                fields = str(raw_line.decode("utf-8")).split(' ')
                # Key by the basename of the image path; the remaining
                # fields are that image's annotation strings.
                anno_dict[fields[0].split("/")[-1]] = fields[1:]
        image_files = []
        for image_file in (set(anno_dict.keys()) & set(image_files_raw)):
            image_files.append(image_file)
            self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
        image_files.sort()
        print("Filter valid data done!")
        return image_files
|
||||
|
||||
|
||||
class DistributedSampler():
    """Sampler that shards YOLOv3 sample indices across devices.

    Each replica receives an equal-length, possibly padded, slice of the
    (optionally shuffled) index sequence.
    """

    def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
        self.dataset_size = dataset_size
        self.num_replicas = 1 if num_replicas is None else num_replicas
        self.rank = (0 if rank is None else rank) % self.num_replicas
        self.epoch = 0
        # At least one full batch per replica, padding up as needed.
        self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministic shuffle keyed on the epoch so replicas agree.
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size).tolist()
        else:
            indices = list(range(self.dataset_size))

        # Pad by repeating the head so the total divides evenly.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Each replica takes a strided slice starting at its rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
|
||||
|
||||
|
||||
def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=8):
    """Create YOLOv3 dataset with GeneratorDataset.

    Builds the file-backed dataset, shards it across `device_num` devices,
    applies per-sample preprocessing/augmentation, and batches the result.
    """
    yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
    # Shard sample indices so each device sees a distinct, equal-size slice.
    distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
    ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
    ds.set_dataset_size(len(distributed_sampler))
    # preprocess_fn expands each sample into the three YOLO target grids
    # plus the padded ground-truth boxes for every scale.
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
    hwc_to_chw = P.HWC2CHW()
    ds = ds.map(input_columns=["image", "annotation"],
                output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                operations=compose_map_func, num_parallel_workers=num_parallel_workers)
    # Convert images from HWC to the CHW layout expected by the network.
    ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
    ds = ds.shuffle(buffer_size=256)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
|
||||
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
IMAGE_DIR=$3
|
||||
ANNO_PATH=$4
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$5
|
||||
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp *.py ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python ../train.py \
|
||||
--distribute=1 \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--image_dir=$IMAGE_DIR \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--anno_path=$ANNO_PATH > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,21 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
|
||||
echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
|
||||
|
||||
python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
######################## train YOLOv3 example ########################
|
||||
train YOLOv3 and get network model files(.ckpt) :
|
||||
python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from mindspore.train import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||
from dataset import create_yolo_dataset
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
|
||||
|
||||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Set learning rate.

    Builds an exponentially decayed schedule of `global_step` values and
    returns the float32 array starting at `start_step`. With `steps=True`
    the decay is staircase (integer exponent); otherwise it is continuous.
    """
    schedule = []
    for step in range(global_step):
        exponent = step // decay_step if steps else step / decay_step
        schedule.append(learning_rate * (decay_rate ** exponent))
    schedule = np.array(schedule).astype(np.float32)
    return schedule[start_step:]
|
||||
|
||||
|
||||
def init_net_param(net, init='ones'):
    """Init the parameters in net.

    Re-initializes every trainable Tensor parameter except batch-norm
    beta/gamma and bias terms.
    """
    excluded = ('beta', 'gamma', 'bias')
    for param in net.trainable_params():
        if not isinstance(param.data, Tensor):
            continue
        if any(keyword in param.name for keyword in excluded):
            continue
        param.set_parameter_data(initializer(init, param.data.shape(), param.data.dtype()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="YOLOv3")
    # NOTE(review): argparse's type=bool treats any non-empty string as True
    # (e.g. --distribute=False is truthy) — confirm launchers only pass "1".
    parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--mode", type=str, default="graph", help="Run graph mode or feed mode, default is graph")
    parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
    parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
    args_opt = parser.parse_args()

    # Graph-mode execution on Ascend with sink optimizations enabled.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
    if args_opt.distribute:
        # Distributed data-parallel training over HCCL.
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_context(enable_hccl=True)
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        init()
        rank = args_opt.device_id
    else:
        context.set_context(enable_hccl=False)
        rank = 0
        device_num = 1

    loss_scale = float(args_opt.loss_scale)
    dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
                                  batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    # Wrap the backbone with the YOLO loss; loss is computed inside the cell.
    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
    init_net_param(net, "XavierUniform")

    # checkpoint
    ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
    ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
    # Optionally resume from an existing checkpoint.
    if args_opt.checkpoint_path != "":
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(net, param_dict)

    # Exponentially decayed learning-rate schedule over all training steps.
    lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)
    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]

    model = Model(net)
    # Dataset sinking is only supported in graph mode.
    dataset_sink_mode = False
    if args_opt.mode == "graph":
        dataset_sink_mode = True
    print("Start train YOLOv3.")
    model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 5f763679fa33de1608d07f7651c6f16012b953ea
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MindSpore package."""
|
||||
|
||||
from . import common, train
|
||||
from .common import *
|
||||
from .ops import _op_impl
|
||||
from .train import *
|
||||
from .log import *
|
||||
from .version import __version__
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(__version__)
|
||||
__all__.extend(common.__all__)
|
||||
__all__.extend(train.__all__)
|
||||
__all__.extend(log.__all__)
|
|
@ -0,0 +1,542 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Check parameters."""
|
||||
import re
|
||||
from enum import Enum
|
||||
from itertools import repeat
|
||||
from collections import Iterable
|
||||
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from .common import dtype as mstype
|
||||
|
||||
|
||||
# Named string regular expression
|
||||
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
|
||||
|
||||
|
||||
class Rel(Enum):
    """Numerical relationship between variables, logical relationship enumeration definition of range."""
    # Scalar comparison relations.
    EQ = 1  # ==
    NE = 2  # !=
    LT = 3  # <
    LE = 4  # <=
    GT = 5  # >
    GE = 6  # >=
    # Scalar range-check relations.
    INC_NEITHER = 7  # (), include neither
    INC_LEFT = 8  # [), include left
    INC_RIGHT = 9  # (], include right
    INC_BOTH = 10  # [], include both
    # Collection membership relations.
    IN = 11
    NOT_IN = 12

    @staticmethod
    def get_strs(rel):
        """Return the message template registered for *rel* ("" when unknown)."""
        template = rel_strs.get(rel)
        return template if template is not None else ""

    @staticmethod
    def get_fns(rel):
        """Return the predicate registered for *rel* (an always-False fallback when unknown)."""
        def fallback(*args):
            return False
        return rel_fns.get(rel, fallback)
|
||||
|
||||
|
||||
# Maps each Rel member to the predicate implementing it.  Comparison and
# membership relations take (x, y); range relations take (x, lower, upper).
rel_fns = {
    # scalar compare
    Rel.EQ: lambda x, y: x == y,
    Rel.NE: lambda x, y: x != y,
    Rel.LT: lambda x, y: x < y,
    Rel.LE: lambda x, y: x <= y,
    Rel.GT: lambda x, y: x > y,
    Rel.GE: lambda x, y: x >= y,
    # scalar range check
    Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
    Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
    Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
    Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
    # collection in, not in
    Rel.IN: lambda x, y: x in y,
    Rel.NOT_IN: lambda x, y: x not in y,
}

# Maps each Rel member to the str.format template used in error messages.
# Comparison templates take one value; range templates take (lower, upper).
rel_strs = {
    # scalar compare
    Rel.EQ: "equal to {}",
    Rel.NE: "not equal to {}",
    Rel.LT: "less than {}",
    Rel.LE: "less or equal to {}",
    Rel.GT: "greater than {}",
    Rel.GE: "greater or equal to {}",
    # scalar range check
    Rel.INC_NEITHER: "({}, {})",
    Rel.INC_LEFT: "[{}, {})",
    Rel.INC_RIGHT: "({}, {}]",
    Rel.INC_BOTH: "[{}, {}]",
    # collection in, not in
    Rel.IN: "in {}",
    Rel.NOT_IN: "not in {}",
}
|
||||
|
||||
|
||||
class ParamValidator:
    """Parameter validator.

    A collection of static checks used by operator definitions.  Every
    method raises ValueError/TypeError on an invalid argument and, where
    noted, returns the validated value unchanged.
    """

    @staticmethod
    def equal(arg_name, arg_value, cond_str, cond):
        """Judging valid value.

        Raises ValueError (mentioning `cond_str`) when `cond` is falsy.
        """
        if not cond:
            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
        """This method is only used for check int values, since when compare float values,
        we need consider float error."""
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')

    @staticmethod
    def check_integer(arg_name, arg_value, value, rel):
        """Integer value judgment.

        Returns `arg_value` when it is an int (bool excluded) satisfying
        `rel` against `value`.
        """
        rel_fn = Rel.get_fns(rel)
        # bool is a subclass of int, so it must be rejected explicitly.
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_shape_length(arg_name, arg_value, value, rel):
        """Shape length judgment (int satisfying `rel` against `value`)."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
        return arg_value

    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
        """This method is only used for check int values,
        since when compare float values, we need consider float error."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int)
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_isinstance(arg_name, arg_value, classes):
        """Check arg isinstance of classes; returns the value on success."""
        if not isinstance(arg_value, classes):
            raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
        """Is it necessary to consider error when comparing float values."""
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_subclass(arg_name, type_, template_type, with_type_of=True):
        """Check whether some type is subclass of another type."""
        # NOTE(review): `Iterable` is imported from `collections` at module
        # top; Python 3.10+ requires `collections.abc` — confirm the
        # supported interpreter versions.
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any([mstype.issubclass_(type_, x) for x in template_type]):
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

    @staticmethod
    def check_args_tensor(args):
        """Check whether args (a dict of name -> value) are all tensor."""
        if not isinstance(args, dict):
            raise TypeError("The args should be a dict.")
        for arg, value in args.items():
            ParamValidator.check_subclass(arg, value, mstype.tensor)

    @staticmethod
    def check_type(arg_name, arg_value, valid_types):
        """Type checking; returns `arg_value` (or its element type for tensors)."""
        def raise_error_msg():
            """func for raising error message when check failed"""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
                             f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        # Tensor-typed mstype values are validated by their element type.
        if isinstance(arg_value, type(mstype.tensor)):
            arg_value = arg_value.element_type()
        # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
        # `check_type('x', True, [bool, int])` will check pass
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()

    @staticmethod
    def check_typename(arg_name, arg_type, valid_types):
        """Does it contain the _name_ attribute."""

        def get_typename(t):
            # Fall back to str() for mstype values that have no __name__.
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if isinstance(arg_type, type(mstype.tensor)):
            arg_type = arg_type.element_type()

        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
        if len(valid_types) == 1:
            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
                             f' but got {get_typename(arg_type)}.')
        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
                         f' but got {get_typename(arg_type)}.')

    @staticmethod
    def check_string(arg_name, arg_value, valid_values):
        """String type judgment; the value must be one of `valid_values`."""
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        if len(valid_values) == 1:
            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
                             f' but got {arg_value}.')
        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
                         f' but got {arg_value}.')

    @staticmethod
    def check_type_same(args, valid_values):
        """Determine whether the types are the same."""
        # The first entry of `args` is the reference every other entry
        # must match (dicts preserve insertion order).
        name = list(args.keys())[0]
        value = list(args.values())[0]
        if isinstance(value, type(mstype.tensor)):
            value = value.element_type()
        for arg_name, arg_value in args.items():
            if isinstance(arg_value, type(mstype.tensor)):
                arg_value = arg_value.element_type()

            if arg_value not in valid_values:
                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
                                f' but `{arg_name}` is {arg_value}.')
            if arg_value != value:
                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')

    @staticmethod
    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
        """Determine whether the types of two variables are the same."""
        if arg1_type != arg2_type:
            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')

    @staticmethod
    def check_value_on_integer(arg_name, arg_value, value, rel):
        """Judging integer type."""
        rel_fn = Rel.get_fns(rel)
        type_match = isinstance(arg_value, int)
        # NOTE(review): the relation is only validated when arg_value IS an
        # int; non-int values pass through unchecked — confirm intended.
        if type_match and (not rel_fn(arg_value, value)):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
        """Judging the equality of parameters."""
        if param1_value != param2_value:
            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
                             f" but got `{param1_name}` = {param1_value},"
                             f" `{param2_name}` = {param2_value}.")

    @staticmethod
    def check_const_input(arg_name, arg_value):
        """Check valid value: a const input must not be None."""
        if arg_value is None:
            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')

    @staticmethod
    def check_float_positive(arg_name, arg_value):
        """Float type judgment: value must be a strictly positive float."""
        if isinstance(arg_value, float):
            if arg_value > 0:
                return arg_value
            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")

        raise TypeError(f"`{arg_name}` must be float!")

    @staticmethod
    def check_pad_value_by_mode(op_name, pad_mode, padding):
        """Validate value of padding according to pad_mode"""
        if pad_mode != 'pad' and padding != 0:
            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
        return padding

    @staticmethod
    def check_empty_shape_input(arg_name, arg_value):
        """Check zeros value: a shape containing 0 describes an empty tensor."""
        if 0 in arg_value:
            raise ValueError(f"Input `{arg_name}` cannot be empty.")

    @staticmethod
    def check_scalar_shape_input(arg_name, arg_value):
        """Check scalar shape input: a scalar's shape is the empty list."""
        if arg_value != []:
            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
|
||||
|
||||
|
||||
def check_int(input_param):
    """Int type judgment.

    Returns *input_param* when it is an int; bool is rejected even though
    it subclasses int.
    """
    is_plain_int = isinstance(input_param, int) and not isinstance(input_param, bool)
    if not is_plain_int:
        raise TypeError("Input type must be int!")
    return input_param
|
||||
|
||||
|
||||
def check_int_positive(input_param):
    """Int type judgment: the value must be a strictly positive int (bool rejected)."""
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if not isinstance(input_param, int):
        raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
    if input_param <= 0:
        raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
    return input_param
|
||||
|
||||
|
||||
def check_int_non_negative(input_param):
    """Non_negative type judgment: the value must be an int >= 0 (bool rejected)."""
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if not isinstance(input_param, int):
        raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
    if input_param < 0:
        raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
    return input_param
|
||||
|
||||
|
||||
def check_int_zero_one(input_param):
    """Judge whether it is 0 or 1."""
    if input_param not in (0, 1):
        raise ValueError("The data must be 0 or 1.")
    return input_param
|
||||
|
||||
|
||||
def check_bool(input_param):
    """Bool type judgment: return the value when it is a bool, else TypeError."""
    if not isinstance(input_param, bool):
        raise TypeError("Input type must be bool!")
    return input_param
|
||||
|
||||
|
||||
def check_input_format(input_param):
    """Judge input format: only "NCHW" is supported."""
    if input_param != "NCHW":
        raise ValueError("The data format must be NCHW.")
    return input_param
|
||||
|
||||
|
||||
def check_padding(padding):
    """Check padding: must compare >= 0."""
    # `not padding >= 0` (rather than `padding < 0`) keeps the original
    # behavior for values where both comparisons are False (e.g. NaN).
    if not padding >= 0:
        raise ValueError("The padding must be at least 0, but got padding {}.".format(padding))
    return padding
|
||||
|
||||
|
||||
def check_padmode(mode):
    """Check padmode: one of "same", "valid" or "pad"."""
    if mode not in ("same", "valid", "pad"):
        raise ValueError("The pad mode must be same or valid or pad, but got mode {}.".format(mode))
    return mode
|
||||
|
||||
|
||||
def check_tensor_supported_type(dtype):
    """Check tensor dtype: only mstype.int32 / mstype.float32 are accepted."""
    if dtype not in (mstype.int32, mstype.float32):
        raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
    return dtype
|
||||
|
||||
|
||||
def _expand_tuple(n_dimensions):
|
||||
"""To expand a number to tuple."""
|
||||
|
||||
def convert(m):
|
||||
if not isinstance(m, tuple):
|
||||
if isinstance(m, int):
|
||||
return tuple(repeat(m, n_dimensions))
|
||||
raise TypeError("Input type must be int or tuple.")
|
||||
|
||||
if not len(m) is n_dimensions:
|
||||
raise TypeError("Input dimension is incorrect.")
|
||||
|
||||
for i in m:
|
||||
if not isinstance(i, int):
|
||||
raise TypeError("Incorrect type inside of a tuple!")
|
||||
return m
|
||||
|
||||
return convert
|
||||
|
||||
|
||||
def check_input_data(*data, data_class):
    """Input data check.

    Recursively verifies that every item (including items inside nested
    lists/tuples) is an instance of `data_class` and is non-empty.
    """
    for item in data:
        if isinstance(item, (list, tuple)):
            # Flatten nested containers by recursing on each element.
            for v in item:
                check_input_data(v, data_class=data_class)
        else:
            if not isinstance(item, data_class):
                raise ValueError(f'Please provide as model inputs'
                                 f' either a single'
                                 f' or a list of {data_class.__name__},'
                                 f' but got part data type is {str(type(item))}.')
            # NOTE(review): assumes `data_class` instances expose a `size()`
            # method (Tensor-like); a plain-type data_class would raise
            # AttributeError here — confirm against callers.
            if item.size() == 0:
                msg = "Please provide non-empty data."
                logger.error(msg)
                raise ValueError(msg)
|
||||
|
||||
|
||||
def check_output_data(data):
    """Output data check: the executor result must be truthy."""
    if data:
        return
    raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
|
||||
|
||||
|
||||
def check_axis_type_int(axis):
    """Check axis type: must be an int."""
    if isinstance(axis, int):
        return
    raise TypeError('Wrong type for axis, should be int.')
|
||||
|
||||
|
||||
def check_axis_range(axis, rank):
    """Check axis range: axis must lie in [-rank, rank)."""
    if -rank <= axis < rank:
        return
    raise ValueError('The axis should be in range [{}, {}), but got {}.'.format(-rank, rank, axis))
|
||||
|
||||
|
||||
def check_attr_int(attr_name, attr):
    """Check int type for the attribute named `attr_name`."""
    if isinstance(attr, int):
        return
    raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
|
||||
|
||||
|
||||
def check_t_in_range(t):
    """Check input range: T must be one of the supported mstype dtypes."""
    supported = (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64)
    if t in supported:
        return
    raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
|
||||
|
||||
|
||||
# Pre-built expanders for 1-D/2-D/3-D integer attributes (e.g. kernel
# sizes and strides): an int is expanded to a tuple of that length.
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
# Scalar and numpy dtypes accepted by the module-level `check_type` below.
valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
                    np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
                    np.float32, np.float64, bool, np.bool_)
|
||||
|
||||
|
||||
def check_type(arg_name, arg_value, valid_types):
    """Check value type.

    Returns the name of the value's type when the value (or a Tensor's
    element type, or a list/tuple/ndarray of supported numpy dtypes) is
    acceptable; raises TypeError otherwise.
    """
    # if input type is Tensor ,get element type
    if isinstance(arg_value, type(mstype.tensor)):
        arg_value = arg_value.element_type()

    # First, check if arg_value has argvalid_types
    if isinstance(arg_value, tuple(valid_types)):
        return type(arg_value).__name__

    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
    if isinstance(arg_value, (list, tuple)):
        arg_value = np.array(arg_value)

    # Thirdly, check the data type by numpy's dtype api
    valid = False
    if isinstance(arg_value, np.ndarray):
        valid = arg_value.dtype in valid_data_types

    # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
    # `check_type('x', True, [bool, int])` will check pass
    if isinstance(arg_value, bool) and bool not in tuple(valid_types):
        valid = False

    if not valid:
        type_names = [t.__name__ for t in valid_types]
        if len(valid_types) == 1:
            raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                            f' but got {type(arg_value).__name__}.')
        raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                        f' but got {type(arg_value).__name__}.')

    return type(arg_value).__name__
|
||||
|
||||
|
||||
def check_typename(arg_name, arg_type, valid_types):
    """Check type name.

    Returns `arg_type` (or a Tensor's element type) when it matches one
    of `valid_types` by equality or isinstance; raises TypeError otherwise.
    """

    def get_typename(t):
        # Fall back to str() for mstype values that have no __name__.
        return t.__name__ if hasattr(t, '__name__') else str(t)

    if isinstance(arg_type, type(mstype.tensor)):
        arg_type = arg_type.element_type()

    if arg_type in valid_types:
        return arg_type
    # NOTE(review): isinstance() requires every entry of valid_types to be
    # a class; non-class mstype entries would raise TypeError here —
    # confirm what callers pass.
    if isinstance(arg_type, tuple(valid_types)):
        return arg_type
    type_names = [get_typename(t) for t in valid_types]
    if len(valid_types) == 1:
        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                        f' but got {get_typename(arg_type)}.')
    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                    f' but got {get_typename(arg_type)}.')
|
||||
|
||||
|
||||
def check_shape(arg_name, arg_value):
    """Check shape: a non-empty tuple of positive integers describing a tensor shape."""
    # First, check if shape is a tuple
    if not isinstance(arg_value, tuple):
        raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
                        f' but got {type(arg_value).__name__}.')

    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
    shape_array = np.array(arg_value)

    # shape can not be ()
    if shape_array.size == 0:
        raise ValueError('Shape can not be empty.')

    # shape's dimension should be 1
    if shape_array.ndim != 1:
        raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(shape_array.ndim))

    # Thirdly, check each element's type of the shape
    allowed_types = (int, np.int8, np.int16, np.int32, np.int64,
                     np.uint8, np.uint16, np.uint32, np.uint64)
    for dim_size in shape_array:
        if not isinstance(dim_size, allowed_types) or dim_size <= 0:
            raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
                             ' but got {}.'.format(dim_size))
|
||||
|
||||
|
||||
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
|
||||
if reg is None:
|
||||
reg = _name_re
|
||||
if re.match(reg, target, flag) is None:
|
||||
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
|
||||
return True
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Extension functions.
|
||||
|
||||
Python functions that will be called in the c++ parts of MindSpore.
|
||||
"""
|
||||
from .utils import cell_attr_register
|
||||
|
||||
__all__ = ["cell_attr_register"]
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""builtin_operations"""
|
||||
import functools
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
||||
|
||||
def scalar_add(x, y):
    """Implement `scalar_add`."""
    return x + y


def scalar_mul(x, y):
    """Implement `scalar_mul`."""
    return x * y


def scalar_mod(x, y):
    """Implement `scalar_mod`."""
    return x % y


def scalar_sub(x, y):
    """Implement `scalar_sub`."""
    return x - y


def scalar_usub(x):
    """Implement `scalar_usub` (unary negation)."""
    return -x
|
||||
|
||||
|
||||
def tuple_getitem(x, index):
    """Implement `tuple_getitem`.

    Tensors are indexed through their numpy view and the result is
    re-wrapped into a Tensor; all other sequences are indexed directly.
    """
    if isinstance(x, Tensor):
        x = x.asnumpy()
        y = x[index]
        return Tensor(y)
    return x[index]
|
||||
|
||||
|
||||
def scalar_gt(x, y):
    """Implement `scalar_gt` (x > y)."""
    return x > y


def scalar_ne(x, y):
    """Implement `scalar_ne` (x != y)."""
    return x != y


def scalar_eq(x, y):
    """Implement `scalar_eq` (x == y)."""
    return x == y


def scalar_le(x, y):
    """Implement `scalar_le` (x <= y)."""
    return x <= y


def scalar_lt(x, y):
    """Implement `scalar_lt` (x < y)."""
    return x < y
|
||||
|
||||
|
||||
def identity(x):
    """Implement `identity`: return the argument unchanged."""
    return x
|
||||
|
||||
|
||||
def zeros_like_tensor(x):
    """Implement `zeros_like_tensor`.

    NOTE(review): np.zeros defaults to float64, so the result does not
    preserve x's dtype — confirm this is intended.
    """
    x = x.asnumpy()
    value = Tensor(np.zeros(x.shape))
    return value
|
||||
|
||||
|
||||
def switch(c, x, y):
    """Implement `switch`: select x when c is truthy, else y."""
    if c:
        return x
    return y
|
||||
|
||||
|
||||
def list_getitem(data, item):
    """Implement `list_getitem`: index into `data`."""
    return data[item]
|
||||
|
||||
|
||||
def bool_not(x):
    """Implement `bool_not`: logical negation (always returns a bool)."""
    return not x


def bool_and(x, y):
    """Implement `bool_and` with Python's short-circuit semantics:
    return x when it is falsy, otherwise y."""
    if x:
        return y
    return x


def bool_or(x, y):
    """Implement `bool_or` with Python's short-circuit semantics:
    return x when it is truthy, otherwise y."""
    if x:
        return x
    return y
|
||||
|
||||
|
||||
def make_list(*xs):
    """Implement `make_list`: collect the positional args into a list."""
    return [*xs]


def list_len(x):
    """Implement `list_len`."""
    return len(x)
|
||||
|
||||
|
||||
# only used in PyNative modes
def partial(*args):
    """Implement `partial`.

    args[0] is the callable; the remaining args are pre-bound positional
    arguments.
    """
    # Bind through __call__ so any callable object (not just plain
    # functions) can be partially applied.
    func = args[0].__call__
    partial_func = functools.partial(func, *args[1:])
    return partial_func
|
||||
|
||||
|
||||
# only used in PyNative modes
def depend(value, expr):
    """Implement `depend`: return `value` unchanged; `expr` is intentionally unused here."""
    return value
|
||||
|
||||
|
||||
def scalar_cast(x, t):
    """Implement scalar_cast: cast scalar `x` to mindspore dtype `t`."""
    np_type = dtype_to_nptype(t)
    value = np_type(x)
    # .item() converts the 0-d numpy scalar back into a Python scalar.
    cast_value = np.ndarray.item(value)
    return cast_value
|
||||
|
||||
|
||||
def typeof(x):
    """Implement typeof: return the mindspore dtype of a Python object."""
    return get_py_obj_dtype(x)
|
||||
|
||||
|
||||
def tuple_to_array(x):
    """Implement `tuple_to_array`: convert a tuple to a Tensor via numpy."""
    return Tensor(np.array(x))
|
||||
|
||||
|
||||
def stop_gradient(x):
    """Implement `stop_gradient`: identity at run time (gradient cut happens in graph mode)."""
    # NOTE(review): forward value is passed through unchanged; the
    # gradient-blocking semantics are handled elsewhere — confirm.
    return x
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Extension functions.
|
||||
|
||||
Python functions that will be called in the c++ parts of MindSpore.
|
||||
"""
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
def _compiletask(platform, *jsons):
    """
    compile func called in single process

    Parameters:
        platform: str. AKG platform or TBE platform
        *jsons: str. json str contain kernel info, suitable for json compile
        api

    Raises:
        ValueError: when any kernel fails to compile.
    """
    if platform == "AKG":
        # Import akg lazily so this module loads even when akg is absent.
        p = __import__("akg", globals(), locals(), ['ms'], 0)
        func = getattr(p.ms, "compilewithjson")
        for json_item in jsons:
            res = func(json_item)
            if not res:
                raise ValueError("Compile error")
    if platform == "TBE":
        # Each json string is piped on stdin to the TBE compiler script,
        # which runs in a separate interpreter process.
        tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
        for json_item in jsons:
            res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
            if res.returncode != 0:
                raise ValueError("Tbe compile error")
|
||||
|
||||
|
||||
def compilekernelparallel(jsons, process, waitime):
    """
    compile kernel use multi processes

    Parameters:
        jsons: list. json str list contain kernel info
        process: int. processes num
        waitime: int. max time the function blocked (seconds)

    Raises:
        ValueError: on invalid arguments.
        RuntimeError: on an unsupported platform entry.
    """
    if not isinstance(jsons, list):
        raise ValueError("jsons must be a list")
    if not isinstance(process, int):
        raise ValueError("process must be a num")
    if not isinstance(waitime, int):
        raise ValueError("waittime must be a num")

    # Split the jobs by target platform.
    jsons_akg = []
    jsons_tbe = []
    for json_ in jsons:
        platform = json.loads(json_)["platform"]
        if platform == "TBE":
            jsons_tbe.append(json_)
        elif platform == "AKG":
            jsons_akg.append(json_)
        else:
            raise RuntimeError("not support this platform {0}".format(platform))

    # Allocate workers proportionally to each platform's share of the jobs.
    # Fix: the share is len(jsons_akg)/len(jsons); the previous inverted
    # ratio len(jsons)/len(jsons_akg) could exceed `process` and drive the
    # TBE worker count negative.  Each platform with pending jobs gets at
    # least one worker.
    if jsons_akg:
        process_akg = max(1, math.floor(len(jsons_akg) / len(jsons) * process))
    else:
        process_akg = 0
    process_tbe = process - process_akg
    if process_tbe == 0 and jsons_tbe:
        # Fix: previously this bumped process_tbe and then immediately
        # raised RuntimeWarning, making the bump dead code.
        process_tbe = 1

    # args[i] = ("TBE"|"AKG", json, json, ...): TBE worker slots first,
    # then AKG worker slots; jobs are dealt round-robin within a platform.
    args = [[] for _ in range(process_akg + process_tbe)]
    for p in range(process_tbe):
        args[p].append("TBE")
    for p in range(process_tbe, process_tbe + process_akg):
        args[p].append("AKG")
    for p, item in enumerate(jsons_tbe):
        args[p % process_tbe].append(item)
    for p, item in enumerate(jsons_akg):
        # Fix: index relative to the actual slot layout instead of `process`,
        # which is wrong when an extra worker was added above.
        args[process_tbe + p % process_akg].append(item)
    args = [tuple(a) for a in args]

    with Pool(processes=process) as pool:
        res = pool.starmap_async(_compiletask, args)
        res.get(timeout=waitime)
    return True
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe common"""
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
class TBEException(Exception):
    """Exception type raised for TBE operator build errors."""

    def __init__(self, err_msg):
        # Mirror the original base-class bookkeeping exactly (it passes
        # the exception instance itself to Exception.__init__).
        super().__init__(self)
        self.__err = err_msg

    def __str__(self):
        return self.__err
|
||||
|
||||
|
||||
def get_ddk_version():
    """Return the DDK version string.

    Resolution order:
      1. the ``DDK_VERSION`` environment variable,
      2. the ``VERSION`` field of the first existing ``ddk_info`` file,
      3. a hard-coded fallback version.

    Returns:
        str, the resolved DDK version.
    """
    ddk_version = os.environ.get("DDK_VERSION")
    if ddk_version is None:
        # Probe the known install locations in priority order instead of
        # duplicating the file-reading logic per path (the original had
        # two copy-pasted with/open branches).
        ddk_info_candidates = (
            '/usr/local/HiAI/runtime/ddk_info',
            '/usr/local/Ascend/fwkacllib/ddk_info',
        )
        for ddk_info_file in ddk_info_candidates:
            if os.path.exists(ddk_info_file):
                with open(ddk_info_file, "r") as fp:
                    ddk_version = json.load(fp)["VERSION"]
                break
        else:
            # No ddk_info file found: fall back to a default version.
            ddk_version = "1.60.T17.B830"
    return ddk_version
|
||||
|
||||
|
||||
def get_build_in_impl_path():
    """Return the directory of the built-in TBE operator implementations.

    Honours the ``TBE_IMPL_PATH`` environment variable first and falls
    back to the known install locations.

    Raises:
        ValueError: if no implementation path can be determined.
    """
    impl_path = os.environ.get("TBE_IMPL_PATH")
    if impl_path is None:
        for candidate in ('/usr/local/HiAI/runtime/ops/op_impl/built-in/ai_core/tbe/',
                          '/usr/local/Ascend/Ascend/opp/op_impl/built-in/ai_core/tbe/'):
            if os.path.exists(candidate):
                impl_path = candidate
                break
    if not impl_path:
        raise ValueError("Can not find the env TBE_IMPL_PATH")
    return impl_path
|
||||
|
||||
|
||||
def _check_arg_info(item):
|
||||
"""
|
||||
Check parameter Validity.
|
||||
|
||||
Args:
|
||||
item (dict): A dict, to be checked.
|
||||
|
||||
Raises:
|
||||
Exception: If specific keyword is not found.
|
||||
"""
|
||||
if 'shape' not in item:
|
||||
raise ValueError("Json string Errors, key:shape not found.")
|
||||
if 'ori_shape' not in item:
|
||||
raise ValueError("Json string Errors, key:ori_shape not found.")
|
||||
if 'format' not in item or not item['format']:
|
||||
raise ValueError("Json string Errors, key:format not found.")
|
||||
if 'ori_format' not in item or not item['ori_format']:
|
||||
raise ValueError("Json string Errors, key:ori_format not found.")
|
||||
if 'dtype' not in item or not item['dtype']:
|
||||
raise ValueError("Json string Errors, key:dtype not found.")
|
||||
|
||||
|
||||
def get_args(op_info, arg_type):
    """
    Parse args.

    Args:
        op_info (dict): Op info dict.
        arg_type (str): arg, to be parsed ('inputs', 'outputs' or 'attrs').

    Returns:
        list, the parsed argument values.

    Raises:
        ValueError: If a required keyword is not found.
    """
    if arg_type not in op_info:
        raise ValueError("Json string Errors, key:{} not found.".format(arg_type))
    args = []
    if not op_info[arg_type]:
        return args

    if arg_type in ('inputs', 'outputs'):
        for item in op_info[arg_type]:
            # A slot holding several tensors is collected into its own
            # sub-list; a single-tensor slot goes directly into args.
            multi = len(item) > 1
            dest = [] if multi else args
            for info in item:
                if 'valid' not in info:
                    raise ValueError("Json string Errors, key:valid not found.")
                if info['valid']:
                    _check_arg_info(info)
                    # Strip bookkeeping keys before handing to the op impl.
                    del info['valid']
                    del info['name']
                    dest.append(info)
                else:
                    dest.append(None)
            if multi:
                args.append(dest)
    elif arg_type == 'attrs':
        for attr in op_info[arg_type]:
            if 'value' not in attr:
                raise ValueError("Json string Errors, attr key:value not found.")
            # 'isRef' is internal bookkeeping, not an op attribute.
            if attr["name"] != "isRef":
                args.append(attr['value'])
    return args
|
||||
|
||||
|
||||
def check_kernel_info(kernel_info):
    """
    Validate that a kernel-info dict carries op_info with name/kernel_name.

    Args:
        kernel_info (dict): Parsed kernel description.

    Raises:
        ValueError: If a mandatory key is missing or empty.
    """
    op_info = kernel_info.get('op_info')
    if not op_info:
        raise ValueError("Json string Errors, key:op_info not found.")
    if 'name' not in op_info or not op_info['name']:
        raise ValueError("Json string Errors, key:name not found.")
    if 'kernel_name' not in op_info or not op_info['kernel_name']:
        raise ValueError("Json string Errors, key:kernel_name not found.")
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe compiler"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from te.platform.cce_conf import te_set_version
|
||||
from te.platform.fusion_manager import op_build_cfg_dis, op_build_cfg_en, set_current_op_name, \
|
||||
init_op_pattern, set_op_params, set_op_build_type, get_op_pattern, set_current_op_func_name
|
||||
from te.platform.fusion_util import fusion_op
|
||||
from common import check_kernel_info, get_args, get_build_in_impl_path, get_ddk_version
|
||||
|
||||
# Resolve the toolkit environment once at import time; both helpers come
# from common.py and may raise if the environment is not set up.
ddk_version = get_ddk_version()
build_in_impl_path = get_build_in_impl_path()

# op function list: the two build modes accepted by build_op().
op_build = "compile"
op_pre_build = "pre_build"
|
||||
|
||||
|
||||
def _initialize(impl_path):
    """
    Point te at the resolved DDK version and operator implementation path.

    Args:
        impl_path (str): Directory of a custom op implementation, or ""
            to use the built-in implementation path.

    Raises:
        ValueError: If no implementation path can be resolved.
    """
    te_set_version(ddk_version)
    op_module_name = build_in_impl_path if impl_path == "" else impl_path
    if not op_module_name:
        raise ValueError("Can not find the env TBE_IMPL_PATH")

    # Make the implementation modules importable by bare name.
    sys.path.insert(0, op_module_name)
|
||||
|
||||
|
||||
def build_op(build_type, json_str):
    """
    call op functions with function name and input args json_str

    Args:
        build_type (str): op function name, one of op_pre_build
            ("pre_build") or op_build ("compile").
        json_str (str): op function input args (a kernel-info json).

    Returns:
        For op_pre_build: the op pattern from get_op_pattern().
        For op_build: whatever the op implementation function returns.

    Raises:
        RuntimeError: wrapping any exception raised during the build
            (including ValueError for unsupported build_type).
    """
    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)

    # import module
    op_name = kernel_info['op_info']['name']

    try:
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(impl_path):
                # Custom implementation file: import it by module name from
                # its containing directory instead of the built-in package.
                path, file_name = os.path.split(impl_path)
                op_name, _ = os.path.splitext(file_name)
                impl_path = path
                custom_flag = True
            # NOTE(review): when impl_path is given but is not a file, the
            # unresolved realpath is still handed to _initialize below —
            # confirm this is intended.
        else:
            impl_path = ""
        _initialize(impl_path)

        inputs_args = get_args(kernel_info['op_info'], 'inputs')
        outputs_args = get_args(kernel_info['op_info'], 'outputs')
        attrs_args = get_args(kernel_info['op_info'], 'attrs')
        kernel_name = kernel_info['op_info']['kernel_name']

        if custom_flag:
            op_module = __import__(op_name)
        else:
            op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0)
        # get function
        if build_type == op_pre_build:
            # set op parameter: configure te's fusion manager for a
            # pattern-only prebuild pass.
            op_build_cfg_dis()
            set_current_op_func_name(op_name)
            set_current_op_name(kernel_name)
            init_op_pattern()
            set_op_params(*outputs_args, *attrs_args, kernel_name=kernel_name)
            set_op_build_type('prebuild')
            if custom_flag:
                py_fn_name = kernel_info['op_info']['name']
            else:
                py_fn_name = op_name
        elif build_type == op_build:
            if custom_flag:
                py_fn_name = kernel_info['op_info']['name']
            else:
                py_fn_name = op_name
        else:
            raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
        op_func = getattr(op_module, py_fn_name, None)
        if op_func is None:
            raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))

        # pre build: run the op function, then re-enable the build config
        # and report the detected fusion pattern.
        if build_type == op_pre_build:
            op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name)
            # disable only pattern configuration
            op_build_cfg_en()
            return get_op_pattern()

        # call function
        return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)

    except Exception as e:
        # Re-enable op build config even on a failed pre-build, then
        # surface the original error wrapped in RuntimeError.
        if build_type == op_pre_build:
            op_build_cfg_en()
        raise RuntimeError(e)
|
||||
|
||||
|
||||
def compile_fusion_op(json_str):
    """
    compile fusion op with input args json_str

    Pre-builds every op listed under 'prebuild_ops', then hands the
    'fusion_op' description to te's fusion_op.

    Args:
        json_str (str): op function input args

    Raises:
        ValueError: If a required keyword is not found.
    """
    args = json.loads(json_str)
    for required in ('fusion_op', 'prebuild_ops'):
        if not args.get(required):
            raise ValueError("Json string Errors, key:{} not found.".format(required))

    for prebuild_op in args['prebuild_ops']:
        build_op(op_pre_build, json.dumps(prebuild_op))
    return fusion_op(json.dumps(args['fusion_op']))
|
||||
|
||||
|
||||
def compile_with_json(json_str):
    """
    Compile tbe with json.

    Dispatches to fusion compilation when the json carries a 'fusion_op'
    key, otherwise builds a single op.

    Args:
        json_str (str): json string describing the op(s) to compile.
    """
    json_info = json.loads(json_str)
    if "fusion_op" in json_info:
        return compile_fusion_op(json_str)
    return build_op(op_build, json_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Invoked as a subprocess (see run_compiler in the tbe process module):
    # the op json arrives as a single line on stdin.
    in_args = sys.stdin.readline()
    compile_with_json(in_args)
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe process"""
|
||||
import traceback
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from .common import check_kernel_info, get_args, get_build_in_impl_path
|
||||
|
||||
build_in_impl_path = get_build_in_impl_path()
|
||||
|
||||
|
||||
def create_tbe_parallel_compiler():
    """
    create TBEParallelCompiler object

    Returns:
        TBEParallelCompiler
    """
    # Returns the module-level CompilerPool singleton rather than
    # constructing a new pool per call.
    return compile_pool
|
||||
|
||||
def op_select_format(op_json: str):
    """
    call op's op_select_format to get op supported format

    Args:
        op_json (str): json string of the op

    Returns:
        op supported format, or "" when the op module does not provide an
        op_select_format hook.

    Raises:
        ValueError: when the kernel info json lacks mandatory keys.
    """
    kernel_info = json.loads(op_json)
    check_kernel_info(kernel_info)

    # Resolve the implementation module: a custom impl_path pointing at a
    # file overrides the built-in implementation directory.
    op_name = kernel_info['op_info']['name']
    impl_path = build_in_impl_path
    custom_flag = False
    if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
        op_impl_path = os.path.realpath(kernel_info['impl_path'])
        if os.path.isfile(op_impl_path):
            path, file_name = os.path.split(op_impl_path)
            op_name, _ = os.path.splitext(file_name)
            impl_path = path
            custom_flag = True
    sys.path.insert(0, impl_path)

    if custom_flag:
        op_module = __import__(op_name)
    else:
        op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)

    # A single getattr replaces the original hasattr + getattr pair; the
    # unused initial `ret = ""` was dropped.
    op_func = getattr(op_module, "op_select_format", None)
    if op_func is None:
        return ""

    # call function
    inputs_args = get_args(kernel_info['op_info'], 'inputs')
    outputs_args = get_args(kernel_info['op_info'], 'outputs')
    attrs_args = get_args(kernel_info['op_info'], 'attrs')
    kernel_name = kernel_info['op_info']['kernel_name']
    return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
|
||||
def check_supported(op_json: str):
    """
    call op's check_supported to check supported or not

    Args:
        op_json (str): json string of the op

    Returns:
        The hook's true/false result, or "" when the op module does not
        provide a check_supported hook.

    Raises:
        ValueError: when the kernel info json lacks mandatory keys.
    """
    kernel_info = json.loads(op_json)
    check_kernel_info(kernel_info)

    # Resolve the implementation module: a custom impl_path pointing at a
    # file overrides the built-in implementation directory.
    op_name = kernel_info['op_info']['name']
    impl_path = build_in_impl_path
    custom_flag = False
    if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
        op_impl_path = os.path.realpath(kernel_info['impl_path'])
        if os.path.isfile(op_impl_path):
            path, file_name = os.path.split(op_impl_path)
            op_name, _ = os.path.splitext(file_name)
            impl_path = path
            custom_flag = True
    sys.path.insert(0, impl_path)

    if custom_flag:
        op_module = __import__(op_name)
    else:
        op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)

    # A single getattr replaces the original hasattr + getattr pair; the
    # unused initial `ret = ""` was dropped.
    op_func = getattr(op_module, "check_supported", None)
    if op_func is None:
        return ""

    # call function
    inputs_args = get_args(kernel_info['op_info'], 'inputs')
    outputs_args = get_args(kernel_info['op_info'], 'outputs')
    attrs_args = get_args(kernel_info['op_info'], 'attrs')
    kernel_name = kernel_info['op_info']['kernel_name']
    return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
|
||||
def run_compiler(op_json):
    """
    run compiler to compile op with subprocess

    The op json is piped to compiler.py (next to this file) on stdin.

    Args:
        op_json (str): json string of the op

    Returns:
        result type, result.
    """
    compiler_script = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
    try:
        subprocess.run([sys.executable, compiler_script], input=op_json, timeout=300,
                       text=True, capture_output=True, check=True)
    except subprocess.TimeoutExpired:
        tb = traceback.format_exc()
        return "TBEException", "CompileTimeOut: " + tb + "\ninput_args: " + op_json
    except subprocess.CalledProcessError as e:
        return "TBEException", "CompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
    return "Success", "Success"
|
||||
|
||||
|
||||
class CompilerPool:
    """Process pool that compiles TBE ops in parallel via run_compiler."""

    def __init__(self):
        # Cap the pool at 32 workers regardless of core count.
        processes = min(multiprocessing.cpu_count(), 32)
        self.__pool = multiprocessing.Pool(processes=processes)
        self.__next_task_id = 1
        # FIFO of (task_id, AsyncResult) pairs, oldest first.
        self.__running_tasks = []

    def __del__(self):
        if self.__pool is not None:
            self.__pool.terminate()
            self.__pool.join()
            del self.__pool

    def exit(self):
        # Intentionally a no-op: the pool is torn down in __del__ instead.
        # (Dead commented-out teardown code removed.)
        return

    def start_compile_op(self, op_json):
        """
        start compile op async.

        Args:
            op_json (str): json string of the op

        Returns:
            int, task id(>0). -1 if error
        """
        task_id = self.__next_task_id
        self.__next_task_id += 1
        task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,))
        self.__running_tasks.append((task_id, task_future))
        return task_id

    def wait_one(self):
        """
        wait until a compile task finish

        Returns:
            int, id of the finished task. -1 if error, 0 if no unfinished task
            str, result of compile task
        """
        ret = 0, "Success"
        if self.__running_tasks:
            task_id, task_future = self.__running_tasks.pop(0)
            # 330s: slightly longer than run_compiler's 300s subprocess
            # timeout, so the worker times out before the pool get() does.
            ret_type, result = task_future.get(330)
            if ret_type == "Success":
                ret = task_id, "Success"
            elif ret_type in ("Exception", "TBEException"):
                ret = task_id, ret_type + ":" + result
            else:
                ret = task_id, "Exception: Not support return type:" + str(ret_type)
        return ret
|
||||
|
||||
|
||||
# Module-level singleton handed out by create_tbe_parallel_compiler().
compile_pool = CompilerPool()
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Interfaces for parser module in c++.
|
||||
"""
|
||||
|
||||
from .parser import (Parser, create_obj_instance, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type,
|
||||
get_class_member_namespace_symbol,
|
||||
get_dataclass_attributes, get_dataclass_methods,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
is_class_member, parse_cb, resolve_symbol)
|
||||
from .serialize import *
|
||||
|
||||
# Public API of the parse package. ('get_dataclass_methods' was listed
# twice in the original; the duplicate is removed.)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
           'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
           'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
           'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
           'get_scope_name']
|
|
@ -0,0 +1,110 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define the namespace of parse."""
|
||||
|
||||
import builtins
|
||||
|
||||
|
||||
class Namespace:
    """
    Base class of namespace for resolve variables.

    Args:
        name (str): The namespace's name.
        dicts (dict): A list of dict containing the namespace's variables.
    """
    def __init__(self, name, *dicts):
        self.name = name
        self.dicts = dicts

    def __contains__(self, name):
        # Membership in any of the backing dicts counts.
        return any(name in mapping for mapping in self.dicts)

    def __getitem__(self, name):
        # First dict wins when a name appears in several.
        for mapping in self.dicts:
            if name in mapping:
                return mapping[name]
        raise NameError(name)

    def __repr__(self):
        return f'Namespace:{self.name}'
|
||||
|
||||
|
||||
class CellNamespace(Namespace):
    """
    Namespace for Cell object.

    Args:
        name (str): Valid module name, it can be imported.
    """
    def __init__(self, name):
        # Resolve names against the module's globals first, then builtins.
        mod_dict = vars(__import__(name, fromlist=['_']))
        super().__init__(name, mod_dict, vars(builtins))

    def __getstate__(self):
        # Only the module name is pickled; the dicts are rebuilt on load.
        return (self.name,)

    def __setstate__(self, state):
        name, = state
        mod_dict = vars(__import__(name, fromlist=['_']))
        super().__init__(name, mod_dict, vars(builtins))
|
||||
|
||||
|
||||
class ClosureNamespace(Namespace):
    """
    Namespace for function closure.

    Args:
        fn (Function): A python function.
    """
    def __init__(self, fn):
        label = f'{fn.__module__}..<{fn.__name__}>'
        # Map each free variable name to its closure cell (empty when the
        # function has no closure).
        cells = fn.__closure__ or ()
        super().__init__(label, dict(zip(fn.__code__.co_freevars, cells)))

    def __getitem__(self, name):
        ns, = self.dicts
        try:
            # cell_contents raises ValueError while the cell is still empty.
            return ns[name].cell_contents
        except ValueError:
            raise UnboundLocalError(name)
|
||||
|
||||
|
||||
class ClassMemberNamespace(Namespace):
    """
    Namespace of a class's closure.

    Args:
        obj (Object): A python class object.
    """
    def __init__(self, obj):
        # The single backing "dict" is the object itself; lookups go
        # through getattr in __getitem__.
        label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
        super().__init__(label, obj)

    def __getitem__(self, name):
        d, = self.dicts
        try:
            return getattr(d, name)
        except ValueError:
            # NOTE(review): getattr raises AttributeError (not ValueError)
            # for a missing member, so a missing attribute propagates as
            # AttributeError rather than UnboundLocalError — confirm this
            # is intended (compare ClosureNamespace, where ValueError is
            # the correct exception for an empty cell).
            raise UnboundLocalError(name)
|
|
@ -0,0 +1,488 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""The module of parser python object, called by c++."""
|
||||
|
||||
import ast
|
||||
import types
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from dataclasses import is_dataclass
|
||||
import asttokens
|
||||
import mindspore.nn as nn
|
||||
from mindspore import log as logger
|
||||
from mindspore import ops
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.api import _MindSporeFunction
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
|
||||
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
|
||||
# define return value
# NOTE: these numeric codes are presumably mirrored on the C++ caller side
# (this module is "called by c++" per its docstring) — keep values stable.
RET_SUCCESS = 0
RET_FAILURE = 0xFF

# define resolve type
RESOLVE_TYPE_NONE = 0  # resolve None
RESOLVE_TYPE_FUNCTION = 1  # resolve function
RESOLVE_TYPE_METHOD = 2  # resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3  # resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4  # resolve the class instance of common class
RESOLVE_TYPE_INVALID = 0xFF

# define the class instance detail type
# When the type is RESOLVE_TYPE_CLASS_INSTANCE
CLASS_INSTANCE_TYPE_CELL = 0  # class instance type is Cell
CLASS_INSTANCE_TYPE_PRIMITIVE = 1  # class instance type is Primitive
CLASS_INSTANCE_TYPE_INVALID = 0xFF

# Ast main type
AST_MAIN_TYPE_STMT = 0  # ast.Stmt
AST_MAIN_TYPE_EXPR = 1  # ast.Expr
AST_MAIN_TYPE_SLICE = 2  # ast.Slice
AST_MAIN_TYPE_UNKNOWN = 0xFF  # unknown

# Ast sub type
AST_SUB_TYPE_AND = 3  # ast.And
AST_SUB_TYPE_OR = 4  # ast.Or
AST_SUB_TYPE_NAME = 5  # ast.Name
AST_SUB_TYPE_TUPLE = 6  # ast.Tuple
AST_SUB_TYPE_SUBSCRIPT = 7  # ast.Subscript
AST_SUB_TYPE_STARRED = 8  # ast.Starred
AST_SUB_TYPE_UNKNOWN = 0xFF  # unknown

# Process expr statement white list
# add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
    "append",
)
|
||||
|
||||
|
||||
def parse_cb(func, parse_method=None):
    """Entry point called from C++: wrap *func* in a Parser instance."""
    return Parser(func, parse_method)
|
||||
|
||||
|
||||
def get_parse_method_of_class(obj, parse_method=None):
    """
    Get parse method of class.

    Args:
        obj (Object): Instance of class.
        parse_method (str): Save the method name. Cell object has default
            method named 'construct'.

    Returns:
        Function, obj's method, or None when no method can be resolved.
    """
    method_name = parse_method
    if method_name is None and isinstance(obj, nn.Cell):
        method_name = "construct"
    if method_name is not None and hasattr(obj, method_name):
        return getattr(obj, method_name)
    return None
|
||||
|
||||
|
||||
def get_bprop_method_of_class(obj, parse_method=None):
    """
    Get bprop method of class.

    Args:
        obj (Object): Instance of class.
        parse_method (str): Unused; kept for interface symmetry with
            get_parse_method_of_class.

    Returns:
        Function, obj's 'bprop' method, or None when absent.
    """
    if isinstance(obj, nn.Cell) and hasattr(obj, "bprop"):
        return getattr(obj, "bprop")
    return None
|
||||
|
||||
|
||||
def resolve_symbol(namespace, symbol):
    """
    Resolve a symbol.

    Note:
        Can't get function when use closure function. So save the fn on namespace.

    Args:
        namespace (Object): Symbol's namespace.
        symbol (str): Need resolve symbol.

    Returns:
        Object, resolve result of symbol (None when resolution fails).
    """
    # All exceptions need to be caught in this function: a failed lookup is
    # deliberately reported as resolve_ = None, not propagated.
    try:
        resolve_ = namespace[symbol]

        # list and dict are not hashable, so they cannot be keys of the
        # convert map; return the lookup result directly.
        if isinstance(resolve_, (list, dict)):
            return resolve_

        # dataclass may not be hashable
        if getattr(resolve_, "__hash__") is None:
            return resolve_

        # If the object has a registered replacement (trope), substitute it.
        if resolve_ in convert_object_map:
            resolve_ = convert_object_map.get(resolve_)
            logger.debug("convert resolve = %r", resolve_)
            if resolve_ == NO_IMPLEMENT:
                raise NotImplementedError("not implemented for ", str(symbol))
    except Exception as e:
        # NotImplementedError is the one failure that must reach the caller.
        if isinstance(e, NotImplementedError):
            raise e
        resolve_ = None
        logger.debug("resolve exception occurred, value = %r", e)
        logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
                     namespace.__str__(), symbol)
    # Unwrap _MindSporeFunction to its underlying python function.
    if isinstance(resolve_, _MindSporeFunction):
        logger.debug("resolve class _MindSporeFunction, resolve fn instead.")
        resolve_ = resolve_.fn
    return resolve_
|
||||
|
||||
|
||||
def generate_scope(obj):
    """Generate the scope for every cell object in the network."""
    # Non-Cell objects are silently ignored.
    if isinstance(obj, nn.Cell):
        obj.generate_scope()
|
||||
|
||||
|
||||
def get_scope_name(obj):
    """Returns the scope of a cell object in one network.

    Returns None for non-Cell objects.
    """
    if isinstance(obj, nn.Cell):
        return obj.get_scope()
    return None
|
||||
|
||||
|
||||
def get_object_key(obj):
    """Return the function key: module + name.

    Returns:
        (str, str): obj_id (always set) and obj_key ("" unless the object
        carries cell_init_args).
    """
    obj_key = ""
    # Base label: class name, plus the object's own __name__ when present.
    if hasattr(obj, "__name__"):
        base = str(obj.__class__.__name__) + str(obj.__name__)
    else:
        base = str(obj.__class__.__name__)
    if hasattr(obj, "cell_init_args"):
        obj_key = f"{base + obj.cell_init_args}_ID"
    obj_id = f"{base}_ID{id(obj)}"
    logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)

    # A bound method has the same id across different instances, so prefix
    # the instance identity to disambiguate.
    if isinstance(obj, types.MethodType):
        method_instance = obj.__self__
        instance_id = f"{str(method_instance.__class__.__name__)}_ID{id(method_instance)}"
        obj_id = instance_id + obj_id
    return obj_id, obj_key
|
||||
|
||||
|
||||
def is_class_member(node):
    """Check whether an AST Attribute node refers to a member of `self`."""
    if node.__class__.__name__ != "Attribute":
        return False
    # Only plain `self.x` counts; anything without a .value.id (e.g. a
    # chained attribute) does not.
    if not hasattr(node.value, "id"):
        return False
    return node.value.id == "self"
|
||||
|
||||
|
||||
def get_obj_type(obj):
    """Get the obj type.

    Returns:
        int, one of the RESOLVE_TYPE_* codes.

    Raises:
        TypeError: for objects outside the supported categories.
    """
    # Early-return ladder; check order matters (MethodType before the
    # generic class-instance test, etc.).
    if obj is None:
        return RESOLVE_TYPE_NONE
    if isinstance(obj, types.FunctionType):
        return RESOLVE_TYPE_FUNCTION
    if isinstance(obj, types.MethodType):
        return RESOLVE_TYPE_METHOD
    if isinstance(obj, type):
        return RESOLVE_TYPE_CLASS_TYPE
    if _is_class_instance(obj):
        return RESOLVE_TYPE_CLASS_INSTANCE
    # For ndarray, report only its shape (the array itself may be huge).
    is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
    raise TypeError(f'Invalid object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
                    f'`{obj.shape if is_ndarray else obj}`.')
|
||||
|
||||
|
||||
def get_class_instance_type(obj):
    """Get the class instance detail type.

    Returns:
        int, one of the CLASS_INSTANCE_TYPE_* codes.
    """
    logger.debug("Get the class type(%r)", obj)
    if not _is_class_instance(obj):
        return CLASS_INSTANCE_TYPE_INVALID
    if isinstance(obj, nn.Cell):
        return CLASS_INSTANCE_TYPE_CELL
    if isinstance(obj, ops.Primitive):
        return CLASS_INSTANCE_TYPE_PRIMITIVE
    # Add the other type base requirement
    return CLASS_INSTANCE_TYPE_INVALID
|
||||
|
||||
|
||||
def _is_class_instance(obj):
    """Confirm the obj is class instance (a Cell, a Primitive, or a dataclass instance)."""
    return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
|
||||
|
||||
|
||||
def _is_dataclass_instance(obj):
|
||||
"""check whether a class is an instance of a dataclass (and not a dataclass itself)"""
|
||||
return is_dataclass(obj) and not isinstance(obj, type)
|
||||
|
||||
|
||||
def create_obj_instance(cls_type, args_tuple=None):
    """Create python instance.

    Args:
        cls_type (type): Class to instantiate; only nn.Cell and
            ops.Primitive subclasses are supported.
        args_tuple (tuple): Positional constructor arguments, or None for
            a no-argument construction.

    Returns:
        Object, the new instance, or None when cls_type is unsupported.
    """
    obj = None
    if isinstance(cls_type, type):
        # check the type, now only support nn.Cell and Primitive
        if issubclass(cls_type, (nn.Cell, ops.Primitive)):
            if args_tuple is not None:
                obj = cls_type(*args_tuple)
            else:
                obj = cls_type()
    return obj
|
||||
|
||||
|
||||
def get_module_namespace(obj):
    """Return a CellNamespace for a module object, or None if *obj* is not a module."""
    logger.debug("get module namespace, module = %r", obj)
    if not isinstance(obj, types.ModuleType):
        logger.warning("Module(%r) is invalid, get namespace failure!", obj)
        return None
    return CellNamespace(obj.__name__)
|
||||
|
||||
|
||||
def get_class_member_namespace_symbol(obj):
    """Get obj class member type.

    Args:
        obj: Class instance whose members are exposed as a namespace.

    Returns:
        ClassMemberNamespace wrapping *obj*.
    """
    logger.debug("get class instance namespace, object = %r", obj)
    class_namespace = ClassMemberNamespace(obj)
    # Fixed misspelled "namesapce" in the debug message.
    logger.debug("class namespace = %r", class_namespace)
    return class_namespace
|
||||
|
||||
|
||||
def get_dataclass_attributes(cls):
    """Map each dataclass field name of *cls* to its MindSpore dtype."""
    attributes = {}
    for field_name, field in cls.__dataclass_fields__.items():
        attributes[field_name] = pytype_to_dtype(field.type)
    return attributes
|
||||
|
||||
|
||||
def get_dataclass_methods(cls):
    """Get functions of dataclass.

    Returns a dict mapping attribute name -> plain Python function for every
    function-typed attribute found on *cls*.
    """
    methods = {}
    for attr_name in dir(cls):
        attr = getattr(cls, attr_name)
        if isinstance(attr, types.FunctionType):
            methods[attr_name] = attr
    return methods
|
||||
|
||||
|
||||
class Parser:
    """
    Parser python code to ast tree.

    Args:
        fn(FunctionType/MethodType): Need parse object instance.
        parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
    """
    def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
        self.fn = fn
        self.parse_method = parse_method
        # First line number of fn's source, used to map ast positions back
        # to the real file in get_location().
        _, self.line_offset = inspect.getsourcelines(self.fn)
        self.filename: str = inspect.getfile(self.fn)

        # Used to resolve the function's globals namespace.
        self.global_namespace = CellNamespace(fn.__module__)
        self.function_module = fn.__module__
        # Used to resolve the function's nonlocals.
        self.closure_namespace = ClosureNamespace(fn)
        self.function_name = fn.__name__
        # Column shift introduced by dedenting the source; set in parse().
        self.col_offset = 0

    def parse(self):
        """Parse the function or method; return the ast tree, or None for invalid input."""
        logger.debug("fn = %r", self.fn)
        tree = None
        if isinstance(self.fn, (types.FunctionType, types.MethodType)):
            original_src = inspect.getsource(self.fn)
            src = dedent(original_src)
            # Record how many leading columns dedent stripped so that
            # get_location() can report columns in the original file.
            self.col_offset = \
                len(original_src.split('\n')[0]) - len(src.split('\n')[0])
            logger.debug("get source = %s", src)
            tree = asttokens.ASTTokens(src, parse=True).tree
        else:
            logger.error("Fn type is invalid")
        return tree

    def get_args(self, node):
        """Return all argument ast nodes of a function-definition node, in order."""
        args = []
        # process position args
        for arg in node.args.args:
            args.append(arg)

        # process kwonlyargs: kwonlyargs are appended after positional args
        if node.args.kwonlyargs:
            for kwarg in node.args.kwonlyargs:
                args.append(kwarg)
        # process vararg: vararg is appended after kwonlyargs
        if node.args.vararg:
            args.append(node.args.vararg)
        # process kwarg: kwarg is appended after vararg
        if node.args.kwarg:
            args.append(node.args.kwarg)
        return args

    def get_args_default_values(self, node):
        """Return default values aligned with get_args(); None marks "no default"."""
        # Positional args without defaults are padded with None at the front.
        nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
        defaults = nondefaults + node.args.defaults + node.args.kw_defaults
        if node.args.vararg:
            defaults.append(None)
        if node.args.kwarg:
            defaults.append(None)
        return defaults

    def get_node_type(self, node):
        """Return [node class name, main ast category] for an ast node."""
        method_name = f'{node.__class__.__name__}'
        node_type = [method_name]
        # judge the ast main type
        if isinstance(node, ast.stmt):
            node_type.append(AST_MAIN_TYPE_STMT)
        elif isinstance(node, (ast.expr, ast.slice)) or node is None:
            # ast.slice and ast.expr are both classified as expressions
            node_type.append(AST_MAIN_TYPE_EXPR)
        else:
            node_type.append(AST_MAIN_TYPE_UNKNOWN)
        return node_type

    def get_ast_type(self, node):
        """Map an ast operator/expression node to its AST_SUB_TYPE_* constant."""
        ast_type = AST_SUB_TYPE_UNKNOWN
        if isinstance(node, ast.And):
            ast_type = AST_SUB_TYPE_AND
        elif isinstance(node, ast.Or):
            ast_type = AST_SUB_TYPE_OR
        elif isinstance(node, ast.Name):
            ast_type = AST_SUB_TYPE_NAME
        elif isinstance(node, ast.Tuple):
            ast_type = AST_SUB_TYPE_TUPLE
        elif isinstance(node, ast.Subscript):
            ast_type = AST_SUB_TYPE_SUBSCRIPT
        elif isinstance(node, ast.Starred):
            ast_type = AST_SUB_TYPE_STARRED
        else:
            ast_type = AST_SUB_TYPE_UNKNOWN
        return ast_type

    def get_namespace_symbol(self, var: str):
        """Resolve *var* against the closure, then the global namespace.

        Falls back to the SYMBOL_UNDEFINE entry of parse_object_map when
        the name is unknown in both.
        """
        if var in self.closure_namespace:
            ops_info = (self.closure_namespace, var)
            logger.debug("in closure_namespace")
        elif var in self.global_namespace:
            ops_info = (self.global_namespace, var)
            logger.debug("in global_namespace")
        else:
            ops_info = parse_object_map.get(SYMBOL_UNDEFINE)
            ops_info = [ops_info[0], var]
        return ops_info

    def get_operation_namespace_symbol(self, var: str):
        """Return (trope namespace, symbol) for an operation name."""
        ops_info = (trope_ns, var)
        logger.debug("get operation ops info = %r", ops_info)
        return ops_info

    def get_ast_namespace_symbol(self, obj):
        """Look up the (namespace, symbol) pair registered for an ast node's type."""
        # step 1: get symbol from object map
        ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
        logger.debug("ops info = %r", ops_info)
        return ops_info

    def get_location(self, node):
        """
        Get location of node start and end line no.

        Args:
            node: AST op node or tuple or List. This is a node in the ANF diagram,
                  here is the code location to get this node.

        Returns:
            List, [fileName, linestart, colstart, lineend, colend].
        """
        ret = [self.filename]
        err_exit = 0
        if isinstance(node, (list, tuple)):
            node_size = len(node)
            if node_size == 0:
                # Empty sequence: no position available, fall through with
                # only the filename in the result.
                err_exit = 1
            else:
                start_node = node[0]
                end_node = node[-1]
        else:
            start_node = node
            end_node = node

        if err_exit == 0:
            if hasattr(start_node, "lineno") and \
                    hasattr(end_node, "col_offset"):
                # first_token/last_token are attached by asttokens in parse().
                start_lineno, start_colno = start_node.first_token.start
                end_lineno, end_colno = end_node.last_token.end
                # Shift ast-relative positions back to the original file.
                start_lineno += self.line_offset - 1
                start_colno += self.col_offset
                end_lineno += self.line_offset - 1
                end_colno += self.col_offset
                ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
            else:
                ret = ret + [0, 0, 0, 0]
        return ret

    def expand_expr_statement(self, node):
        """
        Process the expr statement and expand it.

        Returns:
            tuple, (True, expr.value, x)/(False, None, None).
        """
        if isinstance(node, ast.Expr) and hasattr(node, "value"):
            expr_value = node.value
            if isinstance(expr_value, ast.Call):
                func = expr_value.func
                if isinstance(func, ast.Attribute) and \
                        hasattr(func, "attr") and \
                        hasattr(func, "value"):
                    method = func.attr
                    target = func.value
                    if method in parse_expr_statement_white_list:
                        logger.debug("Expand expr, target:%s, method:%s", target, method)
                        return True, expr_value, target
            # NOTE(review): this branch returns a 2-tuple while the other
            # branches return 3-tuples — confirm callers handle both lengths.
            return True, expr_value
        return False, None, None
|
|
@ -0,0 +1,137 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Resources for ast tree parse."""
|
||||
import ast
|
||||
import math
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from . import standard_method as M
|
||||
from . import trope as T
|
||||
from .namespace import CellNamespace
|
||||
|
||||
|
||||
# Namespaces the parser draws resolved symbols from.
functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite')
trope_ns = CellNamespace('mindspore._extends.parse.trope')

NO_IMPLEMENT = None  # marker: operation has no implementation yet
SYMBOL_UNDEFINE = 0xFF  # marker: undefined var and function

# ops map: {op.type: (Namespace, symbol)} — maps ast operator node types and
# operation names to the namespace/symbol pair the parser resolves them with.
parse_object_map = {
    # ast grammar
    ast.Add: (trope_ns, 'add'),
    ast.Sub: (trope_ns, 'sub'),
    ast.Mult: (trope_ns, 'mul'),
    ast.Div: (trope_ns, 'truediv'),
    ast.FloorDiv: (trope_ns, 'floordiv'),
    ast.Mod: (trope_ns, 'mod'),
    ast.Pow: (trope_ns, 'pow'),
    ast.MatMult: (trope_ns, 'matmul'),
    ast.LShift: (trope_ns, 'lshift'),
    ast.RShift: (trope_ns, 'rshift'),
    ast.And: (trope_ns, 'and_'),
    ast.Or: (trope_ns, 'or_'),
    ast.BitXor: (trope_ns, 'xor'),
    ast.UAdd: (trope_ns, 'pos'),
    ast.USub: (trope_ns, 'neg'),
    ast.Invert: (trope_ns, 'invert'),
    ast.Not: (trope_ns, 'not_'),
    ast.Eq: (trope_ns, 'eq'),
    ast.NotEq: (trope_ns, 'ne'),
    ast.Lt: (trope_ns, 'lt'),
    ast.Gt: (trope_ns, 'gt'),
    ast.LtE: (trope_ns, 'le'),
    ast.GtE: (trope_ns, 'ge'),
    ast.Is: (trope_ns, 'is_'),
    ast.IsNot: (trope_ns, 'is_not'),
    ast.In: (trope_ns, 'contains'),
    ast.NotIn: (trope_ns, 'not_contains'),

    # operation symbol type
    'getitem': (composite_ns, 'getitem'),
    'ms_iter': (composite_ns, 'ms_iter'),
    'ms_next': (composite_ns, 'ms_next'),
    'hasnext': (composite_ns, 'hasnext'),

    # undefined type
    SYMBOL_UNDEFINE: (None, 'undefine'),
}

# convert map: {obj: replacement}
# Escapes a trope/builtin/library symbol to the graph-level operation that
# implements it; NO_IMPLEMENT entries are recognized but unsupported.
convert_object_map = {
    T.add: multitype_ops.add,
    T.sub: multitype_ops.sub,
    T.mul: multitype_ops.mul,
    T.truediv: multitype_ops.div,
    T.getitem: multitype_ops.getitem,
    T.floordiv: NO_IMPLEMENT,
    T.mod: F.scalar_mod,
    T.pow: F.scalar_pow,
    T.matmul: F.dot,
    T.lshift: NO_IMPLEMENT,
    T.rshift: NO_IMPLEMENT,
    T.and_: multitype_ops.logical_and,
    T.or_: multitype_ops.logical_or,
    T.xor: NO_IMPLEMENT,
    T.pos: F.scalar_uadd,
    T.neg: multitype_ops.negative,
    T.invert: NO_IMPLEMENT,
    T.not_: F.bool_not,
    T.eq: multitype_ops.equal,
    T.ne: F.scalar_ne,
    T.lt: multitype_ops.less,
    T.gt: F.scalar_gt,
    T.le: multitype_ops.less_equal,
    T.ge: F.scalar_ge,
    T.is_: F.is_,
    T.is_not: F.is_not,
    T.contains: NO_IMPLEMENT,
    T.not_contains: NO_IMPLEMENT,

    # system function
    T.len: M.ms_len,
    T.bool: M.bool_,
    T.map: C.HyperMap(),
    T.partial: F.partial,
    T.zip: C.zip_operation,

    # custom define operation
    T.iter: M.ms_iter,
    T.next: M.ms_next,
    T.hasnext: M.hasnext,
    T.setitem: M.setitem,

    T.make_tuple: F.make_tuple,
    T.make_dict: F.make_dict,
    T.make_list: F.make_list,
    T.make_slice: F.make_slice,
    T.range: F.make_range,

    # lib function
    math.floor: NO_IMPLEMENT,
    math.trunc: NO_IMPLEMENT,
    math.exp: NO_IMPLEMENT,
    math.log: F.scalar_log,
    math.sin: NO_IMPLEMENT,
    math.cos: NO_IMPLEMENT,
    math.tan: NO_IMPLEMENT,
}
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""The functions in this file is used to dump and load python object in anf graphs."""
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import stat
|
||||
|
||||
|
||||
def dump_obj(obj, path):
    """Pickle *obj* into a file named after its id under *path*.

    *path* is used as a raw string prefix (caller supplies any separator).
    Returns the generated file name (hex of the object's id).
    """
    file_name = hex(id(obj))
    target_path = path + file_name
    with open(target_path, 'wb') as sink:
        # Restrict the dump to owner read/write before data is written.
        os.chmod(target_path, stat.S_IWUSR | stat.S_IRUSR)
        # NOTE: pickle output — only load it back with trusted files.
        pickle.dump(obj, sink)
    return file_name
|
||||
|
||||
|
||||
def load_obj(file_path):
    """Unpickle and return the object stored at *file_path*."""
    try:
        resolved = os.path.realpath(file_path)
    except Exception as ex:
        raise RuntimeError(ex)
    with open(resolved, 'rb') as src:
        # NOTE: pickle.load executes arbitrary code for hostile inputs;
        # only use on files produced by dump_obj.
        loaded = pickle.load(src)
    return loaded
|
||||
|
||||
|
||||
# Public API of this dump/load helper module.
__all__ = ['dump_obj', 'load_obj']
|
|
@ -0,0 +1,235 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""standard_method"""
|
||||
from dataclasses import dataclass
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like
|
||||
from ...ops.composite.base import _append
|
||||
|
||||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']

# Shared Transpose primitive, reused by transpose() below.
trans = P.Transpose()
|
||||
|
||||
|
||||
def transpose(x):
    """Implementation of `transpose`: reverse all axes of *x*."""
    shape = F.shape(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    # Full transpose == applying the reversed identity permutation.
    revert_perm = F.tuple_reversed(perm)
    out = trans(x, revert_perm)
    return out
|
||||
|
||||
|
||||
def getitem(data, item):
    """Implementation of `getitem`; delegates to the object's __getitem__."""
    return data.__getitem__(item)
|
||||
|
||||
|
||||
def setitem(data, item, value):
    """Implementation of `setitem`; delegates to the object's __setitem__."""
    return data.__setitem__(item, value)
|
||||
|
||||
|
||||
def ms_iter(xs):
    """Implementation of `iter`; delegates to the MindSpore __ms_iter__ protocol."""
    return xs.__ms_iter__()
|
||||
|
||||
|
||||
def ms_next(it):
    """Implementation of `next`; delegates to the MindSpore __ms_next__ protocol."""
    return it.__ms_next__()
|
||||
|
||||
|
||||
def hasnext(it):
    """Implementation of `hasnext`; delegates to the MindSpore __ms_hasnext__ protocol."""
    return it.__ms_hasnext__()
|
||||
|
||||
|
||||
def ms_len(data):
    """Implementation of `len`; delegates to the object's __len__."""
    return data.__len__()
|
||||
|
||||
|
||||
def floor(x):
    """Implementation of `floor`; delegates to the object's __floor__."""
    return x.__floor__()
|
||||
|
||||
|
||||
def trunc(x):
    """Implementation of `trunc`; delegates to the object's __trunc__."""
    return x.__trunc__()
|
||||
|
||||
|
||||
def uadd(x):
    """Implementation of unary `+`; delegates to the object's __pos__."""
    return x.__pos__()
|
||||
|
||||
|
||||
def usub(x):
    """Implementation of unary `-`; delegates to the object's __neg__."""
    return x.__neg__()
|
||||
|
||||
|
||||
def scalar_truediv(x, y):
    """Implementation of `scalar_truediv`; delegates to x.__truediv__."""
    return x.__truediv__(y)
|
||||
|
||||
|
||||
def scalar_floordiv(x, y):
    """Implementation of `scalar_floordiv`; delegates to x.__floordiv__."""
    return x.__floordiv__(y)
|
||||
|
||||
|
||||
def bool_(x):
    """Implementation of `bool`; delegates to the object's __bool__."""
    return x.__bool__()
|
||||
|
||||
|
||||
def tensor_bool(x):
    """Return *x* unchanged; x is already a tensor of bool value."""
    return x
|
||||
|
||||
|
||||
def and_(x, y):
    """Implementation of `and` (`&`); delegates to x.__and__."""
    return x.__and__(y)
|
||||
|
||||
|
||||
def or_(x, y):
    """Implementation of `or` (`|`); delegates to x.__or__."""
    return x.__or__(y)
|
||||
|
||||
|
||||
def matmul(x, y):
    """Implementation of `matmul` (`@`); delegates to x.__matmul__."""
    return x.__matmul__(y)
|
||||
|
||||
|
||||
def float_bool(x):
    """Implementation of `float_bool`: a float is truthy when non-zero."""
    return x != 0.0
|
||||
|
||||
|
||||
def int_bool(x):
    """Implementation of `int_bool`: an int is truthy when non-zero."""
    return x != 0
|
||||
|
||||
|
||||
def str_bool(x):
    """Implementation of `str_bool`: a string is truthy when non-empty."""
    if x == "":
        return False
    return True
|
||||
|
||||
|
||||
def list_bool(x):
    """Implementation of `list_bool`: a list is truthy when non-empty."""
    return len(x) != 0
|
||||
|
||||
|
||||
def tuple_bool(x):
    """Implementation of `tuple_bool`: a tuple is truthy when non-empty."""
    return len(x) != 0
|
||||
|
||||
|
||||
def dict_bool(x):
    """Implementation of `dict_bool`: a dict is truthy when non-empty."""
    return len(x) != 0
|
||||
|
||||
|
||||
def none_bool(x):
    """Implementation of `none_bool`: None is always falsy."""
    return False
|
||||
|
||||
|
||||
def float_floordiv(x, y):
    """Implementation of `float_floordiv`: true division followed by floor."""
    return floor(x / y)
|
||||
|
||||
|
||||
#############
|
||||
# Iteration #
|
||||
#############
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SequenceIterator:
    """
    SequenceIterator is a util dataclass for iterating sequence object.

    Iterator to use for sequences like List, Array.
    """

    # Current position within seq.
    idx: int
    # The sequence being iterated.
    seq: list

    @core(ignore_values=True)
    def __ms_hasnext__(self):
        """Whether the index is past the length of the sequence."""
        return self.idx < ms_len(self.seq)

    @core(ignore_values=True)
    def __ms_next__(self):
        """Return the next element and a new iterator."""
        # Frozen dataclass: advance by constructing a new iterator, not mutating.
        return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
|
||||
|
||||
|
||||
def list_iter(xs):
    """Iterator for List; starts at index 0."""
    return SequenceIterator(0, xs)
|
||||
|
||||
|
||||
def array_iter(xs):
    """Iterator for Array; starts at index 0."""
    return SequenceIterator(0, xs)
|
||||
|
||||
|
||||
def tuple_next(xs):
    """Next tuple: return (head, remaining tuple)."""
    return xs[0], tail(xs)
|
||||
|
||||
|
||||
def tuple_hasnext(xs):
    """Whether the tuple is empty or not."""
    return len(xs) > 0
|
||||
|
||||
|
||||
def list_next(xs):
    """Next list: return (head, remaining list)."""
    return xs[0], tail(xs)
|
||||
|
||||
|
||||
def list_hasnext(xs):
    """Whether the list is empty or not."""
    return len(xs) > 0
|
||||
|
||||
|
||||
def list_append(self_, item):
    """Implementation of `list.append`; delegates to composite _append."""
    return _append(self_, item)
|
||||
|
||||
|
||||
#################
|
||||
# Array methods #
|
||||
#################
|
||||
|
||||
|
||||
def to_array(x):
    """Implementation of `to_array`; delegates to the __ms_to_array__ protocol."""
    return x.__ms_to_array__()
|
|
@ -0,0 +1,93 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Trope some system function symbol to ops."""
|
||||
|
||||
# This operation function is not meant to be called directly
|
||||
|
||||
# support operator symbol, ast
|
||||
from operator import ( # noqa
|
||||
add, sub, mul, truediv, floordiv, mod, eq, ne, lt, gt, le, ge, pos, neg,
|
||||
not_, and_, or_, xor, lshift, rshift, invert, is_, is_not, contains,
|
||||
matmul, getitem, setitem
|
||||
)
|
||||
|
||||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip
|
||||
)
|
||||
|
||||
# support functools
|
||||
from functools import ( # noqa
|
||||
partial
|
||||
)
|
||||
|
||||
# support numpy symbol
|
||||
from numpy import ( # noqa
|
||||
exp, log, sin, cos, tan
|
||||
)
|
||||
|
||||
# Public re-export list: operator/builtin/functools/numpy symbols that the
# graph compiler can trope through this module.
__all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'gt', 'le', 'ge', 'pos', 'neg',
           'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
           'matmul', 'getitem', 'setitem',
           'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
           'partial',
           'exp', 'log', 'sin', 'cos', 'tan']
|
||||
|
||||
|
||||
def make_tuple(*elts):  # pragma: no cover
    """Tuple builder placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def make_dict(key, value):  # pragma: no cover
    """Dict builder placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def make_list(*elts):  # pragma: no cover
    """List builder placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def make_slice(*elts):  # pragma: no cover
    """Slice builder placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def make_range(*elts):  # pragma: no cover
    """Range tuple builder placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def switch(cond, tb, fb):  # pragma: no cover
    """Switch placeholder (returns one of two values); resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def hasnext(it):  # pragma: no cover
    """Hasnext placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def to_array(x):  # pragma: no cover
    """to_array placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def not_contains(x):  # pragma: no cover
    """`not in` placeholder; resolved by the graph compiler."""
    message = 'This operation is not meant to be called directly.'
    raise RuntimeError(message)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Pynative mode help module."""
|
||||
from inspect import signature
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def args_type_check(*type_args, **type_kwargs):
|
||||
"""Check whether input data type is correct."""
|
||||
|
||||
def type_check(func):
|
||||
sig = signature(func)
|
||||
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal bound_types
|
||||
bound_values = sig.bind(*args, **kwargs)
|
||||
argument_dict = bound_values.arguments
|
||||
if "kwargs" in bound_types:
|
||||
bound_types = bound_types["kwargs"]
|
||||
if "kwargs" in argument_dict:
|
||||
argument_dict = argument_dict["kwargs"]
|
||||
for name, value in argument_dict.items():
|
||||
if name in bound_types:
|
||||
if value is not None and not isinstance(value, bound_types[name]):
|
||||
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return type_check
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Some utils."""
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def _update_digest(digest, handle, buf_size, to_bytes):
    """Feed *handle* into *digest* in buf_size chunks; *to_bytes* maps a chunk to bytes."""
    while True:
        data = handle.read(buf_size)
        if not data:
            break
        digest.update(to_bytes(data))


def cal_sha256(file_path):
    """
    Calculate sha256 value of input file.

    Args:
        file_path (str): file path to calculate.

    Returns:
        str, returns sha256 hex digest of input file or empty str if it
        cannot be read.
    """
    buf_size = 64 * 1024  # read in 64 KiB chunks to bound memory use
    sha256 = hashlib.sha256()
    file_real_path = os.path.realpath(file_path)
    if not os.path.isfile(file_real_path) or not os.access(file_real_path, os.R_OK):
        return ""
    _, extension = os.path.splitext(file_path)
    try:
        # Previously the chunk-read loop was duplicated per branch; both
        # branches now share _update_digest.
        if extension == '.ptx':
            # .ptx kernels are hashed as utf-8 encoded text, preserving the
            # historical digest for text-mode reads.
            with open(file_real_path, 'r') as kf:
                _update_digest(sha256, kf, buf_size, lambda chunk: chunk.encode("utf-8"))
        else:
            with open(file_real_path, 'rb') as kf:
                _update_digest(sha256, kf, buf_size, lambda chunk: chunk)
    except IOError:
        logging.error("Open file %s failed.", file_path)
        return ""
    return sha256.hexdigest()
|
||||
|
||||
|
||||
def cell_attr_register(fn=None, attrs=None):
    """
    Cell init attributes register.

    Registering the decorator of the built-in operator cell __init__
    function will add save all the parameters of __init__ as operator attributes.

    Args:
        fn (function): __init__ function of cell.
        attrs (list(string) | string): attr list.

    Returns:
        function, original function.
    """
    def wrap_cell(fn):
        @wraps(fn)
        def deco(self, *args, **kwargs):
            arguments = []
            if attrs is None:
                # No explicit attrs: record every bound __init__ argument
                # (except self) in declaration order.
                bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
                arguments = bound_args.arguments
                del arguments['self']
                # NOTE(review): this leaves a dict_values view, so str(arguments)
                # below renders as "dict_values([...])" — presumably relied on
                # downstream; confirm before "cleaning up" to a list.
                arguments = arguments.values()
            fn(self, *args, **kwargs)
            if attrs is not None:
                # Explicit attrs: snapshot the named instance attributes
                # after __init__ has run.
                if isinstance(attrs, list):
                    for item in attrs:
                        if not isinstance(item, str):
                            raise ValueError(f"attr must be a string")
                        if hasattr(self, item):
                            arguments.append(getattr(self, item))
                elif isinstance(attrs, str):
                    if hasattr(self, attrs):
                        arguments = getattr(self, attrs)
                else:
                    raise ValueError(f"attrs must be list or string")
            # Identifier combining the class name with its init arguments.
            self.cell_init_args = type(self).__name__ + str(arguments)
        return deco
    if fn is not None:
        # Used as a bare decorator: @cell_attr_register
        return wrap_cell(fn)
    # Used as a decorator factory: @cell_attr_register(attrs=...)
    return wrap_cell
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""__init__"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import sys
|
||||
import os
|
||||
|
||||
def AKGAddPath():
    """Put this package's directory at the front of sys.path (deduplicated)."""
    here = os.path.dirname(os.path.realpath(__file__))
    tvm_path = os.path.realpath(here)
    # Drop any existing occurrence first, then re-insert at the front, so the
    # path appears exactly once and always wins module resolution.
    if tvm_path in sys.path:
        sys.path.remove(tvm_path)
    sys.path.insert(0, tvm_path)
|
||||
|
||||
|
||||
class AKGMetaPathFinder:
    """Meta-path finder that redirects "akg.tvm"/"akg.topi" imports."""

    def find_module(self, fullname, path=None):
        """Return a loader for akg-prefixed tvm/topi modules, else None."""
        if fullname.startswith(("akg.tvm", "akg.topi")):
            # Strip the leading "akg." to obtain the real module name.
            return AKGMetaPathLoader(fullname[4:])
        return None
|
||||
|
||||
|
||||
class AKGMetaPathLoader:
    """Loader that imports a real module and aliases it under its "akg." name."""
    def __init__(self, rname):
        # Real (unaliased) module name, e.g. "tvm" for "akg.tvm".
        self.__rname = rname

    def load_module(self, fullname):
        """Freshly import self.__rname and register it in sys.modules as *fullname*."""
        if self.__rname in sys.modules:
            # Drop any cached copy so the path adjustment below takes effect.
            sys.modules.pop(self.__rname)
        AKGAddPath()
        __import__(self.__rname, globals(), locals())
        self.__target_module = sys.modules[self.__rname]
        # Alias: future "akg.xxx" imports resolve to the same module object.
        sys.modules[fullname] = self.__target_module
        return self.__target_module
|
||||
|
||||
|
||||
# Install the finder so "akg.tvm"/"akg.topi" imports resolve via AKGMetaPathLoader.
sys.meta_path.insert(0, AKGMetaPathFinder())
|
||||
|
||||
from .op_build import op_build
|
||||
from .message import compilewithjson
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""__init__"""
|
||||
from .equal import Equal
|
||||
from .equal import gpu_schedule_Equal
|
||||
from .tile import Tile
|
||||
from .tile import gpu_schedule_Tile
|
||||
from .cast import Cast
|
||||
from .cast import gpu_schedule_Cast
|
||||
from .relu6 import ReLU6, gpu_schedule_ReLU6
|
||||
from .relu6_grad import ReLU6Grad, gpu_schedule_ReLU6Grad
|
||||
from .squeeze import Squeeze, gpu_schedule_Squeeze
|
||||
from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
|
||||
from .mean import SimpleMean, gpu_schedule_SimpleMean
|
||||
from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
|
||||
from .mul import Mul, gpu_schedule_Mul
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""cast"""
|
||||
import logging
|
||||
import akg.tvm
|
||||
from akg.ops.math import cast
|
||||
from akg.topi.generic import schedule_elemwise
|
||||
|
||||
def Cast(x, dst_type):
    """Cast tensor ``x`` to ``dst_type`` using akg's cast op.

    Args:
        x (tvm.tensor.Tensor): input tensor.
        dst_type: target data type.

    Returns:
        tvm.tensor.Tensor, the cast result.
    """
    converted = cast.cast(x, dst_type)
    return converted
|
||||
|
||||
|
||||
def gpu_schedule_Cast(outs):
    """Create a CUDA elementwise schedule for Cast.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        schedule.Schedule: the created schedule, or None when CUDA is
        unavailable.
    """
    device = 'cuda'
    if not akg.tvm.context(device, 0).exist:
        # NOTE(review): sibling schedule functions raise SystemError here,
        # while this one logs and returns None — confirm callers accept None.
        logging.info("Skip because %s is not enabled", device)
        return None
    with akg.tvm.target.create(device):
        return schedule_elemwise(outs)
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""default schedule function for GPU"""
|
||||
from queue import Queue
|
||||
|
||||
import akg.tvm as tvm
|
||||
|
||||
DEFAULT_GPU_THREAD = 1024
|
||||
|
||||
|
||||
def default_schedule(outs):
    """Build a default CUDA schedule binding axis 0 of every compute op.

    Args:
        outs (Union[tvm.tensor.Tensor, list[tvm.tensor.Tensor]]): outputs of
            compute.

    Returns:
        sch (schedule.Schedule): The created schedule.

    Raises:
        ValueError: if ``outs`` is neither a tensor nor a list.
        SystemError: if the CUDA device is not enabled.
    """
    if not isinstance(outs, (tvm.tensor.Tensor, list)):
        raise ValueError("outs should be list of akg.tvm.tensor.Tensor or akg.tvm.tensor.Tensor")
    device = 'cuda'
    if not tvm.context(device, 0).exist:
        raise SystemError("Skip because %s is not enabled" % device)
    outs_list = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    with tvm.target.create(device):
        sch = tvm.create_schedule(outs_list[0].op)
        # Breadth-first walk from the first output, collecting every
        # ComputeOp exactly once.
        visited_ops = []
        pending = [outs_list[0]]
        while pending:
            tensor = pending.pop(0)
            if tensor.op not in visited_ops and isinstance(tensor.op, tvm.tensor.ComputeOp):
                visited_ops.append(tensor.op)
                pending.extend(tensor.op.input_tensors)
        # Split each op's outermost axis into block/thread and bind to CUDA.
        for op in visited_ops:
            stage = sch[op.output(0)]
            block_x, thread_x = stage.split(op.axis[0], factor=DEFAULT_GPU_THREAD)
            stage.bind(block_x, tvm.thread_axis("blockIdx.x"))
            stage.bind(thread_x, tvm.thread_axis("threadIdx.x"))
    return sch
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""equal"""
|
||||
import akg.tvm
|
||||
from akg.ops.math import equal
|
||||
from akg.topi.generic import schedule_elemwise
|
||||
|
||||
def Equal(x, y):
    """Elementwise equality of two tensors via akg's equal op.

    Args:
        x (tvm.tensor.Tensor): first input.
        y (tvm.tensor.Tensor): second input.

    Returns:
        tvm.tensor.Tensor, the comparison result.
    """
    result = equal.equal(x, y)
    return result
|
||||
|
||||
|
||||
def gpu_schedule_Equal(outs):
    """Create a CUDA elementwise schedule for Equal.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        schedule.Schedule: the created schedule.

    Raises:
        SystemError: if the CUDA device is not enabled.
    """
    device = 'cuda'
    if not akg.tvm.context(device, 0).exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with akg.tvm.target.create(device):
        return schedule_elemwise(outs)
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""mean op compute and schedule"""
|
||||
import akg.tvm as tvm
|
||||
from akg.ops.math.mean import mean
|
||||
from .default_schedule import DEFAULT_GPU_THREAD
|
||||
|
||||
def Mean(x, axis=None, keepdims=True):
    """Compute the mean of ``x`` over ``axis``.

    Args:
        x (tvm.tensor.Tensor): input tensor.
        axis: axes to reduce over; None reduces all axes.
        keepdims (bool): keep reduced dimensions as size-1 axes.

    Returns:
        tvm.tensor.Tensor, the mean result with the helper output stripped.
    """
    result = mean(x, axis, keepdims)
    # mean may return a tuple; only the first tensor is the real result.
    result = result[0] if isinstance(result, tuple) else result
    # Skip the redundant "mean_output" wrapper node when present.
    if result.op.name == "mean_output":
        result = result.op.input_tensors[0]
    return result
|
||||
|
||||
|
||||
def gpu_schedule_Mean(outs):
    """Create a CUDA schedule for Mean.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        sch (schedule.Schedule): The created schedule.
    """
    output = outs[0] if isinstance(outs, list) else outs
    with tvm.target.create("cuda"):
        sch = tvm.create_schedule(output.op)
        # The "T_divide" node is the mean's real output; otherwise the output
        # is a squeeze wrapping it.
        divide = output if output.op.name == "T_divide" else output.op.input_tensors[0]
        producer = divide.op.input_tensors[0]
        axes = divide.op.axis
        # Compute the producer at the second axis when available, else the first.
        anchor = axes[1] if len(axes) >= 2 else axes[0]
        sch[producer].compute_at(sch[divide], anchor)
        # Bind the outermost axis to CUDA blocks/threads.
        block_x, thread_x = sch[divide].split(axes[0], factor=DEFAULT_GPU_THREAD)
        sch[divide].bind(block_x, tvm.thread_axis("blockIdx.x"))
        sch[divide].bind(thread_x, tvm.thread_axis("threadIdx.x"))
    return sch
|
||||
|
||||
def SimpleMean(x):
    """Mean of a 4D tensor over its last two axes, keeping reduced dims.

    Args:
        x (tvm.tensor.Tensor): Tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor, same type as x; an (a, b, c, d) input yields an
        (a, b, 1, 1) output.
    """
    return Mean(x, axis=(2, 3), keepdims=True)
|
||||
|
||||
|
||||
def gpu_schedule_SimpleMean(outs):
    """Schedule SimpleMean on GPU by delegating to gpu_schedule_Mean."""
    sch = gpu_schedule_Mean(outs)
    return sch
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""mean_grad"""
|
||||
import akg.tvm as tvm
|
||||
import akg
|
||||
from akg.ops.math import mean
|
||||
from .default_schedule import DEFAULT_GPU_THREAD
|
||||
|
||||
|
||||
def mean_ad(head, input_shape, axis, keepdims):
    """Differentiate mean w.r.t. its input via akg autodiff.

    Args:
        head (tvm.tensor.Tensor): output gradient.
        input_shape: shape of the mean's input tensor.
        axis: axes the forward mean reduced over.
        keepdims (bool): whether the forward mean kept reduced dims.

    Returns:
        tvm.tensor.Tensor, gradient w.r.t. the mean input.
    """
    x_placeholder = tvm.placeholder(input_shape, head.dtype, "A")
    forward = mean.mean(x_placeholder, axis, keepdims)
    # mean may return a tuple; keep only the real result tensor.
    if isinstance(forward, tuple):
        forward = forward[0]
    # Strip the redundant "mean_output" wrapper node when present.
    if forward.op.name == "mean_output":
        forward = forward.op.input_tensors[0]
    gradients = list(akg.differentiate(forward, [x_placeholder], head))
    return gradients[0]
|
||||
|
||||
|
||||
def MeanGrad(y_grad, input_shape, axis=None, keepdims=True):
    """Gradient of mean.

    Args:
        y_grad (tvm.tensor.Tensor): output gradient.
        input_shape: shape of the mean's input tensor.
        axis: axes the forward mean reduced over; None means all.
        keepdims (bool): whether the forward mean kept reduced dims.

    Returns:
        tvm.tensor.Tensor, gradient w.r.t. the mean input.

    Raises:
        ValueError: for the unsupported axis=None, keepdims=False combination.
    """
    unsupported = axis is None and not keepdims
    if unsupported:
        raise ValueError("Mean not support (axis=None && keepdims=False) now")
    return mean_ad(y_grad, input_shape, axis, keepdims)
|
||||
|
||||
|
||||
def gpu_schedule_MeanGrad(outs):
    """Create a CUDA schedule for MeanGrad.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        sch (schedule.Schedule): The created schedule.
    """
    output = outs[0] if isinstance(outs, list) else outs
    with tvm.target.create("cuda"):
        sch = tvm.create_schedule(output.op)
        producer = output.op.input_tensors[0]
        axes = output.op.axis
        # Compute the producer at the second axis when available, else the first.
        anchor = axes[1] if len(axes) >= 2 else axes[0]
        sch[producer].compute_at(sch[output], anchor)
        # Bind the outermost axis to CUDA blocks/threads.
        block_x, thread_x = sch[output].split(axes[0], factor=DEFAULT_GPU_THREAD)
        sch[output].bind(block_x, tvm.thread_axis("blockIdx.x"))
        sch[output].bind(thread_x, tvm.thread_axis("threadIdx.x"))
    return sch
|
||||
|
||||
def SimpleMeanGrad(HEAD, input_shape):
    """Gradient of SimpleMean (mean over the last two axes, keepdims).

    Args:
        HEAD (tvm.tensor.Tensor): output gradient, dy, defined in Primitive.
        input_shape (Union[list[int], tuple[int]]): shape of mean input, x.shape.

    Returns:
        tvm.tensor.Tensor, gradient of mean input.
    """
    return MeanGrad(HEAD, input_shape, axis=(2, 3), keepdims=True)
|
||||
|
||||
|
||||
def gpu_schedule_SimpleMeanGrad(outs):
    """Schedule SimpleMeanGrad on GPU by delegating to gpu_schedule_MeanGrad.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        sch (schedule.Schedule): The created schedule.
    """
    sch = gpu_schedule_MeanGrad(outs)
    return sch
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""mul"""
|
||||
import akg.topi as topi
|
||||
import akg.tvm as tvm
|
||||
from akg.ops.math import mul
|
||||
|
||||
def Mul(x, y):
    """Elementwise multiplication of two tensors via akg's mul op.

    Args:
        x (tvm.tensor.Tensor): first factor.
        y (tvm.tensor.Tensor): second factor.

    Returns:
        tvm.tensor.Tensor, the product.
    """
    product = mul.mul(x, y)
    return product
|
||||
|
||||
|
||||
def gpu_schedule_Mul(outs):
    """Create a CUDA broadcast schedule for Mul.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        schedule.Schedule: the created schedule.

    Raises:
        SystemError: if the CUDA device is not enabled.
    """
    device = 'cuda'
    if not tvm.context(device, 0).exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with tvm.target.create(device):
        return topi.cuda.schedule_broadcast(outs)
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""relu6"""
|
||||
import akg.topi as topi
|
||||
import akg.tvm as tvm
|
||||
from akg.topi import tag
|
||||
|
||||
@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_relu6(x):
    """Elementwise min(max(x, 0), 6) tagged as an elementwise op."""
    # Hoist the clamp bounds out of the lambda; same constants per element.
    lower = tvm.const(0, x.dtype)
    upper = tvm.const(6, x.dtype)
    return tvm.compute(x.shape, lambda *i: tvm.min(tvm.max(x(*i), lower), upper))
|
||||
|
||||
def ReLU6(x):
    """
    Compute elementwise with function: min(max(x, 0), 6).

    Args:
        x (tvm.tensor.Tensor): Tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor, has same type and shape as input.
    """
    clamped = topi_nn_relu6(x)
    return clamped
|
||||
|
||||
|
||||
def gpu_schedule_ReLU6(outs):
    """Create a CUDA elementwise schedule for ReLU6.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        schedule.Schedule: the created schedule.

    Raises:
        SystemError: if the CUDA device is not enabled.
    """
    device = 'cuda'
    if not tvm.context(device, 0).exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with tvm.target.create(device):
        return topi.cuda.schedule_elemwise(outs)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""relu6 grad"""
|
||||
import akg.topi as topi
|
||||
import akg.tvm as tvm
|
||||
|
||||
def ReLU6Grad(y_grad, x):
    """
    Computes Gradients of Rectified Linear 6.

    Args:
        y_grad (tvm.tensor.Tensor): Tensor of type float16, float32, gradients backpropagated to the ReLU6 op.
        x (tvm.tensor.Tensor): Tensor of type float16/float32, inputs that where passed to the ReLU6 op, or its outputs.

    Returns:
        tvm.tensor.Tensor, has same type and shape as x.
    """
    shape = x.shape
    dtype = x.dtype
    zero = tvm.const(0, dtype)
    six = tvm.const(6, dtype)

    # Clamp x in two stages: negative part to zero, then values >= 6 to zero.
    positive = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= zero, x(*i), zero))
    clamped = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= six, zero, positive(*i)))
    # The gradient passes through only where the clamped value is non-zero.
    grad = tvm.compute(shape, lambda *i: tvm.if_then_else(clamped(*i) == zero, zero, y_grad(*i)))
    return grad
|
||||
|
||||
|
||||
def gpu_schedule_ReLU6Grad(outs):
    """Create a CUDA elementwise schedule for ReLU6Grad.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        schedule.Schedule: the created schedule.

    Raises:
        SystemError: if the CUDA device is not enabled.
    """
    device = 'cuda'
    if not tvm.context(device, 0).exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with tvm.target.create(device):
        return topi.cuda.schedule_elemwise(outs)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue